min-dalle-test/min_dalle.ipynb

224 lines
2.5 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-04 14:21:12 +00:00
"execution_count": 2,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-03 19:35:33 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-03 22:40:27 +00:00
},
"id": "ix_xt4X1_6F4",
2022-07-04 14:21:12 +00:00
"outputId": "36db2d18-c7d0-4603-e0dd-b289e7a99592"
2022-06-28 00:58:17 +00:00
},
2022-07-03 19:35:33 +00:00
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-07-03 19:35:33 +00:00
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-04 14:21:12 +00:00
"Collecting min-dalle\n",
" Downloading min-dalle-0.2.17.tar.gz (10 kB)\n",
2022-07-03 19:35:33 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
2022-07-04 14:21:12 +00:00
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.2.17-py3-none-any.whl size=10987 sha256=b1a239f5428d6ccc757f47094ce1e1986d3900a754cc7f3c0a362c4ffdba7162\n",
" Stored in directory: /root/.cache/pip/wheels/d4/8b/2d/14470dada5426179003743c4f0b8bfc9804c68c21a9275eb62\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.2.17\n",
"Mon Jul 4 14:16:04 2022 \n",
2022-07-03 19:35:33 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-04 14:21:12 +00:00
"| N/A 35C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-03 19:35:33 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"# Pin min-dalle to the release this notebook was last run against (0.2.17, per\n",
"# the recorded install log) so a fresh runtime reproduces the same setup.\n",
"# %pip (rather than ! pip) installs into the environment of the running kernel.\n",
"%pip install min-dalle==0.2.17\n",
"# Report the GPU attached to this runtime.\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-04 14:21:12 +00:00
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-04 14:21:12 +00:00
"outputId": "0afeaa09-c0df-4c2f-bdb2-19cebbd86281"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-02 13:31:20 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-02 13:31:20 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-02 13:31:20 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-02 13:31:20 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"# Constructing MinDalle logs its progress to stdout: parameters are downloaded\n",
"# and the tokenizer, encoder, decoder and detokenizer are initialized in turn.\n",
"model = MinDalle(\n",
"    is_mega=True,\n",
"    is_reusable=True,\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-04 14:21:12 +00:00
"Note: reduce `grid_size` if you run out of GPU memory. A 4x4 grid (`grid_size = 4`) has been tested to work on T4 and P100 GPUs."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-04 14:21:12 +00:00
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 14:21:12 +00:00
"height": 1000
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-04 14:21:12 +00:00
"outputId": "7e711399-0b3f-48ee-808e-27a51509829b"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-07-03 19:35:33 +00:00
"['Ġdali']\n",
"['Ġpainting']\n",
2022-07-01 01:23:15 +00:00
"['Ġof']\n",
2022-07-03 19:35:33 +00:00
"['Ġwal', 'le']\n",
"text tokens [0, 21853, 1545, 111, 563, 92, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-04 14:21:12 +00:00
"<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F33C65140D0>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAEAAElEQVR4nIz9efg123UWBr5rrT1U1Rl+0zfe7w66o3QlWdZgW7awIw/I2DjgQPDAE/MEB3AHCAECIWkgNCRPGJoGGrppujFDIDgGM9mxwbKxjSVjWbJGa7yS7r268/3G33TOqao9rdV/fN+1BYR+ev9V56laa9eps8/aa3jfVYS/9o/x+RNUj8Bwht0GnmEKq2AGAUYohuhRC2JAzYCBPTKBBK6BCNyhZniHWtAqmoIYNcEEQ0RrIIE1NIAF84jBoRSIgBwKgQUMNAMT6gwGCKgCx6gJxLAAR7AGEJBRKwpDJ/gGcXAOCiQFGciACgY8kBP6Ducjeo+WESLYoDNQUBpEMEfEAdzQe2Djw7QsN19/EFbzZ7/7+35reumFxXDh1sbOptaYmFZdDD3xcrHXmVvQat0v1gHRF6Mbzz//gTy/8E/+wc9+aeSb26tPPvHmb/nqd73nP/raw4P5+RdeOCvn40J+5qOzDQ889oYnv/Mte2/ExYb5GM/+o185+QP/yR/Cc8/jCF/5Nctvfvdb9mS72Of9g/tW1w431Z/d6K+/srn16naxPPBtEwda+tjmentXP/PS5sWXd6cj/NETfsGrB+4/c8O0uO++r3wUD77xdnHXj5/DcYAlpB1wBmqYz5Em2AidYYRxCx8wZkwFfUTNCD1SgQSAIAHCmEcIwxHIIIRxhu/BDTBkRSeoCgFKhRA2M6JAGozADI5IDc6DFXmCD+CGKYEd/NCxIc1NpGoMrELI2qpVZAMxdPbOmjE5aZUBYhIWqmkXQkQZYUadTzvqguuQaqvbwqFzaloV0CLMDcA4YxBMBVYRBmSFj4gBzfty8PaHLs9jiLnthF3wedwFJ461WSJiM4gxKcTLbp6XXUCdMqAmrnHgxlSaOE/Bytx7SbVMraqiN/Y6w7EykRoRazM2qPE4bQ6WXmuuRCIRmaDGHVLKi+BRRiWrRqjsrAbXlJgoqqVAUqyNLavqgOB0Mo8mQUBqgMJUjWQ3nx+to+U5M0r2JXFNidzMIo5pcGWcpqqGVuZUBi97gwMQOk4ppzQvOz+PYx+dcGNr/UI8224351SiYyZmZB/UBwa0zFozM1ku6fCwBxereZxOFv0RtVjmLH4Ck2fnY1svVymPc944cO+82AwQRxE2gbaqROJpwXIcujGNgaRnWzBCU4NYaZtF78gmpZyzt7wPJe52RM6a99FIW9U8t5FajNQBszKpkFkTkpQbc2Tqmx2zP52nVjVo2dPSTakaqyKzq+zSuJtv3arz5IMLrldtIPL9OoybebNLJOg74iYtzUa6OFixqDWrUyMnecZyz8FP27MCcmILasvNpkhg38M0hahzGk/P2Ms6CIelMXkzFwaHYilVP7A2CzxY3dU2hcVKvKAUq6SQaZS4qKEvrermfBvDRaohzZVEQyfsWMQ2283ZBkRh1Q8kzQvlUg2S5twH1JqIhgQbFvnStb2//vRFJMOsEKAXnG/RCgZGUbADAkjADBBaQdfBCqYZ/YDZ4Bw4wwAX0SpAIAMBY4N3oBlTQhdABGWoohJ8xHiKZUTLyA3Wge+aGgIUBlgFK9DQBLlAFBIRAkpFEJABBi0ohJYRDYEAAwxzBgs8wzeIAgnzDsbwAAWEBSbAR8RDDoOCsStI9Zu+4de/5b77Ll2+bGE1+SvAdOe4zmRz6M532+PTumGX8/Rq7rZtQprhBBLcYrWwOc11wW5h3s/T7tbJAY+Xh0Q6zm2eVYShxsraDLmhmGsKESrz7LQt+9U64uSZL936yM8sb372T/yZH3jHN36Dpv7wgQcI2aY4Yzo+3uQ0VXN7RwuVpY7jZtNun03nu9sn2xF+ODu7Xqdj79ax0f7lJ4bDa5m8tU0oleeT7XzngTdeeeWZT3/iw5/41Id+Gbr95BdfycdAxZcNwgO/+ff8pe9/+M3f9tLJS7/8/Jeevj
2eji9gSqAN5hGbc9QtvCGNODnHeoXgsd1gt0M/wDmYwuZ7P1xWIAIVUqGKBqhACexAhjQhRNgEzTBBFhgjAOLBEWmCu6ukoRGGgDLe26adwAi1gRxcxO6O2x/qNKEq0IEjDIiCoiAHSygTXESLIHN9bblZNcDBEZqCDSQggkXEC3/uf/rebe22N04KaiB1nnalOCfVmGPMBSjahQjhVEt0bES3dxsrtHShd96JNu/ZAmkRQSYc76Y6laNF7EPd5VYVnh3gK9TABIzj2cXDvuWkzmtlLqQmFGlXqjcOrE1zdpw2tSNdRzC3XSLvo5Br7FS41lxz64c41rTTdjynq+t1TvPgPQsXpaJ5WzPN9WJwF5crV7j3HCQvog42opRLl9YPXru03dSPfOyTc9bv/I7ffHpriou1aubiu4VDZe9cq0SBGstUp2yS6lbbq/n81jzfFBFB7MMqDqFHKWcvx9jiOu5un0toy/7iEI8IqzY73zP5Rq6c3Dzbv9Y3nJyV6be/+b/98NP4/z3e8A1PmrOyPf2qd77te7//e2o9Prm9/Yf/2z+9dePGnRu723c25ZXXLhW8+z14+Vkc3odv+o1v+D0/8Nsv7X3zC5tedBuLZzknt2UzrduXXt2ezzdZ5k258eIXPv4P/+qnP/ThL5vyvit/4G/+N9Pi7a9stufnZ7zc25men+76dSyqrabNPM+btAj9wbK/s92YEJlGL2a82268uYW46Cpznccch76kmpL6YZGt5JJayT7Irs5pRMddJOc8jZAu9HmeQPCdu3O6I6K9jvO4Odmmxbp37FUpWYFRyTUGWq7dZjulVGABcwV56jorCh9QZjDQRdzeYYhwDWZIilRBDAPEY1gADRzJXXY4cwgrhAQfsR2x2kNg5Iw6wwgEgBAEtUE6mAczuKE6LBeogCY4Rm0QRlE4AjyWK+SMFoECVcABgAl6h1zQdxACHJigjKVHUeQKF0CANEjDxFjvwRymcwAwuxd1sKF5dAP2PFJE2wEEEFTRR7iAmsAGIWhCF9Av0F+AzmiKXMEECkCFADAsergAUgwLLFdliQdXDz5im1934dLRrWcffeiR51HuzGenY5nnVe/9eq+7b7Far1cHi8Vh2I9ZSqs95zm1h+5/+MVnX73vscc+/tk7r4xp9fQzl97zTQ8+9OB+dFcuPPbPfuoj/+v/8hPv/ZmX8cwd7NFf+bo3/Yn/8tu+79e/Ky6WT/3yh4EdDoEreOhBe/MT3SNXL5s0Cgdpgx5Xr15bff3RnL4SZzPamFwv57vjk7OU+/ioC/dflNPT+qXz07Op3Pjc9TtB+/uuvDx9qH3i6uwiLq6xfwHsUBw6gnaIBCwwnqMUlIIQUYE846jDuENcQA39EkbwHYSQZoQFxIEKqKAUrA4Wvt9NZygFLsAAEGoBCNqwXJJqz7UU4yAtVecCkZtyYj+sVM7KKZxB/F5cXDC7lXVkH4wCgWotpGgz4uKQ+jkZcWmNAXaqC99pk9Np511YFZ4dqc1TnbC8fC0OJ7uTDDgmRwKlwARGddp2E6IXcW0IgKFFeIV3qIohFHc+hbdZOvMDeXZjyov1qo9s4xiIK+AJaKhGKZWD2C3BW+mq2wXrL2NJNZ/xBFhnZl76RnBSojEtjzJN5ppN3pgdE6OSsriUy956GQUmPrvMlQ665VTnU03Lbjiq/jzWRFu04eJw6Gve2ImDkDZlWVZfhHNfgP7KtJoRxnDuDZ48O01JPS93ZTzaX/fE2rvqzxa2vMgHOe8Kt1TLQJLK+YXD9Z35XCvd3y3ruNOWvCisuWV//30H4zxevLQnpbaSB+9qHUG2utCbkSPnlEtJDaPWRMQx2MIvxjLK/r6l1rEl2/UHYdHvd9YjzVUyNYkIBbthEkeCBR90q+G8VdpVSRERlR2rmSH2ZGrJedZAnvulKwOyN/jciqNFn7DTOwlbR7Qf15I1UyvKjlclzUJKXJvr16tD2TUFz5q8darkjBbwDRGENE5MI6
yF4RKXQXPXgsxUqapTOy4nu3b70sX9YPs8I1EuFd7HnHMIw6XDLllbLTqZasmTcqYiCidi0pP5MMUdo7Zq/
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-04 14:21:12 +00:00
"CPU times: user 47.4 s, sys: 78.2 ms, total: 47.5 s\n",
"Wall time: 48.1 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
"# Inputs, exposed as Colab form fields through the #@param annotations.\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 14:21:12 +00:00
"grid_size = 4 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
2022-06-28 15:05:59 +00:00
"\n",
"# Generate the image with the model loaded earlier and render it inline.\n",
"display(\n",
"    model.generate_image(text, seed, grid_size)\n",
")"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}