min-dalle-test/min_dalle.ipynb

229 lines
645 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-04 21:27:02 +00:00
"execution_count": 3,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-04 21:27:02 +00:00
"outputId": "918d48d7-9d8e-48fa-ab32-9fbc142e0074",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-04 21:27:02 +00:00
"Collecting min-dalle==0.2.22\n",
" Downloading min-dalle-0.2.22.tar.gz (11 kB)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.22) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.22) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.22) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.22) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.22) (2.23.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.22) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.22) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.22) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.22) (2022.6.15)\n",
2022-07-04 20:30:39 +00:00
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
2022-07-04 21:27:02 +00:00
" Created wheel for min-dalle: filename=min_dalle-0.2.22-py3-none-any.whl size=11263 sha256=84726f3fdf87ac6ccf6731fefadcc29154efee2a26020b771dcd08d4724ccf01\n",
" Stored in directory: /root/.cache/pip/wheels/7b/db/df/ec2e6cb890f0f527e178401b3f2fecaa7106748f2bdbc69bb9\n",
2022-07-04 20:30:39 +00:00
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
2022-07-04 21:27:02 +00:00
"Successfully installed min-dalle-0.2.22\n",
"Mon Jul 4 21:23:14 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-04 21:27:02 +00:00
"| N/A 37C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"! pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-04 21:27:02 +00:00
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-04 21:27:02 +00:00
"outputId": "e5df04d8-bae6-4d00-acd8-eeb908d2a54f"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-04 20:30:39 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-04 20:30:39 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-04 20:30:39 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-04 20:30:39 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
"Note: if you run out of GPU memory, reduce `grid_size`. A 4x4 grid has been tested to work on T4 and P100 GPUs (with `intermediate_image_count = 1`)."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-04 21:27:02 +00:00
"execution_count": 5,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 20:30:39 +00:00
"height": 563
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-04 21:27:02 +00:00
"outputId": "0213d6b2-a343-42d8-df03-7dc5dc768584"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-04 21:27:02 +00:00
"<PIL.Image.Image image mode=RGB size=512x512 at 0x7F946818F890>"
2022-07-04 14:21:12 +00:00
],
2022-07-04 21:27:02 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAEAAElEQVR4nFz9d5yt2XUWCD/PWmu/7zmnwr23bk6dg7rVytEKtmzZFg44exgGEBgPDAMmOgDfxwC/IXwwzDCY8efBYBswYJvBASMHCUtCkpVbaoVW6Fbn7ts3p6pbdc55373XWt8fp9rMfOfP+lXVeevU3muv9aTNH/iP7x+uDs1oRBFkVgrH5mjRyXSBbgFb1Hp8rdx9oju6bqfX9fRWd+zw7Nc+88I//pN/b+fx9wMJFKAvxx44cPquxfaN1ubD1V3MbwEjoCzaHZi+6evf9D/+pW/7yqO3/u7/9M/t1ry1r+Xmq3HrRcTzAACeuPP1f/Fv/Nnv+953nTk0+1c/9b/+ub/yU0RkHUDBwTOYHcS1C1ieAwwgcPDt3/9d/+gf/Kmf/5l/f/SMX/vKlf/8ng8/d23u4wis6+bJ2+44nvObPeu3/bH/5o3f8pZnzi0++fSNr11diKVIii+EUEKFhKtAkchWNIWREQRGbxNhqjeINoGgCSSyE9wSTFAmuVwCBiHRMuAqXHrCRBXdyGC0Er7AWHLdUkbLDJkwd7OG2Hrm4HPSDOLh1dmJm+WQmHJtjByilQRkqJBp612WwaIEMm4x1twklyOjiGV4C6neZl0OEYKeIQu2LowY9uCas56jZ7SUTHrz1M5Lv1cOPb0zwcJBIB0MxABRtAABEkx4o8wUaDEgBZnIhky6JhYQQRAoIMAEHHRkTy3pCyQgCgoQ8ABGRIACGpDIhAQQaIRNIIQowpGKrkc2jA41UICGAAoRI0oHFUiCpGrWARHQDgGYgYnm8IAZUFEDInCHK4TIhhZgAQKtAYp0REUEkvARGhBBS4CIBlVgiShgAR00MMAABZloI0QhQK5+vMISALIARARAsCITUSAOElAwEA00tAoCQmQigQA4QhXJ/c8ZgASiAhPI6uNVIJEVFHgCDWogEEASqDCCKlIiEwDJ9AHsQCDr/o9HIhVmkGn39rN1XXMckA5LqGM5AgYqqJCAJGoCU8AhBIlsiIROEBUJSAEAJAKQCibCwAkaAQEqsiENUuAV7CBANgQgCh+gilREgEBbwhIB2AROBJGANrRE4/5j1wAUECwrujWs91BBJNpl/Jdfxm99jMA33tb9rX/0yheuLH/87335L/7YK4+tDf/hVx57/CN4fkTbwuQPf+OJt3897GCdTKZaTnbxjnLx3/6xnz5+dvjpX/+hL186/Ve/96fPP/EUgAfecfd3/k9/YbEznQ57G/2TBybP/Ph3//bYgD/zbjz4rTh0FEGMA4YlOoVXaCIbCITuLw8jfIR0SIE3UOAOJkLQKcY5CFCRASbGAT3gDbaGEIwOABgQqxULEAju/weloC4QDjWkIIAASoM7aEDZX1HpqANsDUaMS4AIoiqsB3uM67bwWVcie/E6KFnAlgA1xaqoB5yxsTbr1ibjANXd7Z0bz35p57HPPP3bv/ofdp75MlCAE7CtYy+/78QdJ/du3ZCUmxd2Md7A6qFI2Thw9OyRtU1Z841bV7+GG5/G1t24Kdh+FKgAgQTy4jOffc//9XNvf9V9973uLdpvou6mzUBisompYb6NaQGPYrEABJg/9oVHf+89Hzz/3OU3vvEdD5yN1Hz0Kxcff+zitWt7vrt95ardfnrtkMUH/s3PvvjFz7zru7/joQM6Xw5DCtBGphDc32RBBAWdFsDDMylAkpmgh4pKX/qKsYiIJBxmuaE9FlpthEdkBjgt0kKYmqoBRkQHE5oQPTslQzxaNvdE9p2UajUqKJ6RgomaZLZsQkuRxFiEJWz0gZKz0g2KlkxvLVMLN8t0GOkcPZnShDKRUjEmVcVgoumlaU1Tcq1MGOKtSkaGlK4L0Sa5G4Em6CqEUENUjLJflSKRCQKlwEyBiIxoGBsYoKE3jAQJVSCQCcr+UhaiCLwgHAlkwBQdMBBaQMIKoiGBFKRjYi
gdWoIBJETREw5YARNtBIC1CXIJD/QCARWaQngFEQpNwECgVQjQ90BDbeg7sGEYoVPAUQeIohBJiCKJCGQCBZ2hAbUChCQScO7vVTOoggCISCBROulLtAFZ0RIUqOwfPFSoIhPB/Z3piWkHOjLgiQhY4WQGbxkDYlXoBZIYA1SQSCAakgAgCi2Q3C+yHhBhPzWU6nvIBBIgVdMTGVCFEBGIzCREUHoo6UwH0qEm3STCmW0UIgfMAACmaIG1AhoS8IpMCKAGUWSCCR/hFaXH1FABb8hAJgoRDR5IQWcQQSYi0BxDRdejE4QiKloDiOkEvWFekUQ0EICjJDxgBcXgATFEQx0BYDpFAtWhggSGEZ1iIjAACRXUbuv47Iaid3z3bbpRjp68Z8Mmzz3yqeW7v//s1vrjd57KnRu46sbucGwe3ZvK0NktlInL+nZZr/XI2pGDswc+97tPnX/iqekatCtHXv5atq0Z/cBsdqjzL3zoc23E7O2vXLz+Hdkr1gLDiJhDHUFAMetQgQgoEAEmkDCDGpKw1RcrItDP0ClYwIAnCEjAABLSQwSaYCIIdwTRdwAQFd5W6wd9gbaX1phDZb8pgYEGAu6AIBOdwQoKgYqWEEWZQgRjoNs2GRaBFi599UyHOxJmBV3XmpjoNGVNFUOb1+1Pfvb3fu83Hr16ucO5LyG/AijkZXrqwbMPHT99uD3/tcevXLo23tqJWwOaAAk4+o3b773v67/x9Q+96q5F6d/7uUczve7eQAxAAwBsADNgDtz61CMv/NTP/8JrH3poHlOgoVUAyNQyDSLHAVtbuHQDuYRfaLeeef8nv/j+j3zt27//Ha951Sv7vr/99icfuOeFjz3yxLNPX927fPMiuHHn9Pix5VOPfuJ39q6+7m1vvY+bzw39XFSLeGRkZIQSbajsVEwyWikFgYgWakpTwKM1BpIKuke0sOReqZNO+hRqAVhzdE9KN51M6NDAzMqYLWqq28iWzHBPz7HWVIsxqsVEJwg3TDwrgxVRxATq4QZJIgG4SMgtXazaTI8grauxbfOuYwcV7wIlwiMiUgsNiQy3lEZ3j0Ib2tJUVYQ1GEyNjNAI01WTQkAwBgSc9AhIIjIQsCwVY7oPWFVnQApMAWQmSg8K1BBJZK7GAiZEMxNqMAOwX2c90U2gBgRqggoECEDBDkmoAwQNSqwmPwl4gwFUtBFoKCYe6S2dnpW5qly9aB+ZgEMKQLijNWgHEg2wfn8O6HukAYIaUN1vx7rpfs21jn0RCDwCkc1LmKNFVoIJAYUQaJeS0ZoIEkYzMMMbtJfCSFAUYEaTlGzMkqAACSqZkp0T6cFONIqgqMiQNSN0ogAFIomaDUjLMnJACIFkECKwUElm09RiTO2iDDJGkCQhJAARtdTQlEQ6A8EUEWY2RWoCAFPA2nJTEQ0kvIGrzocAk0oRoToqYgn4/qSCCVgwOopSNJNIIirSoAC7lxpSB4gQWI8kakASnkiDdWhAq7CyvwCqI4gklPvHDwIxIAIuEEMCY0UASnigN0CgI9oANSgw1YPH7ryeOAH81CcX//hPvP8nf+q7jhzZ/OpjlxNbb3xQlkfj7Zf7v/ufBrl6nePYOihbbWPbHqd7azpbO3WmXr387O++59fPnj78zX/2+w/e+4rDBw8WYZ3qsDj/2Mff86s/+0mZ9JtvettyupFYos5XJ+/+YyOwbFCBKiIhWC3N/d4Iuj8ir8ZlEC3R9YgGITKQDhiUsILmQMIEnmgdjGgAE6KIBAxiqAHrEAIH2MEDIITQDkFwdXwSNaAJEiHoppAGJ5JgQhtA2xIXDpOqWC4XylFS1KqwJkGmiKJ0XPZ18fgnfvsz/+YXRn8dNirySQDAqfLgQ6fu39QbLzzyka8sblU0oCXKGkIQVwA7fdedr3/Ng6944IH77run32gHTh4BgOEK/utrADaAE0DzbX7iU1/9wlc+lp1gegQLAiOG4R2vfeXjN5fnHnsC2WHtMOqA+bXdnV
sXz53zS+fmN+qRQ0f0fusmh2cHTvbrhx85+MyXvvritct70vjql03Pnjly5doLX/zEh1/+ildpOfnkMF1qN
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-04 21:27:02 +00:00
"CPU times: user 33.9 s, sys: 551 ms, total: 34.4 s\n",
"Wall time: 35.3 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
2022-07-04 20:30:39 +00:00
"display_size = 512 #@param {type:\"integer\"}\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_shape = (display_size, display_size, 3)\n",
"zero_image = Image.fromarray(numpy.zeros(image_shape, dtype=numpy.uint8))\n",
"display(zero_image, display_id=1)\n",
"\n",
"def handle_intermediate_image(row_index: int, image: Image.Image):\n",
" image = image.resize((display_size, display_size))\n",
" update_display(image, display_id=1)\n",
2022-06-28 15:05:59 +00:00
"\n",
2022-07-04 20:06:14 +00:00
"image = model.generate_image(\n",
" text,\n",
" seed,\n",
" grid_size,\n",
2022-07-04 21:27:02 +00:00
" log2(intermediate_image_count),\n",
2022-07-04 20:06:14 +00:00
" handle_intermediate_image\n",
")"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}