min-dalle-test/min_dalle.ipynb

228 lines
657 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-04 20:30:39 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
"outputId": "5476920e-a25f-487e-cef7-c5bec2af6c89",
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting min-dalle\n",
" Downloading min-dalle-0.2.21.tar.gz (11 kB)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.2.21-py3-none-any.whl size=11251 sha256=3a2376a44b744efae76959cd51b32bcd47e6c726516a9e4f2249a2c09b84ef3e\n",
" Stored in directory: /root/.cache/pip/wheels/ed/4a/f4/0c03a5fb54f04f081359043cc71bfbcce2cf4994a6dd6b89a9\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.2.21\n",
"Mon Jul 4 20:25:12 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 41C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"# Pin to the version this notebook was last run with (see output above); later\n",
"# releases may change the positional generate_image(...) arguments used below.\n",
"! pip install min-dalle==0.2.21\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-04 20:30:39 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-04 20:30:39 +00:00
"outputId": "b5a1f639-6033-4f2a-9f9e-00fc3e9ad554"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-04 20:30:39 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-04 20:30:39 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-04 20:30:39 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-04 20:30:39 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"# Build the model once here; the generation cell below reuses this `model` object.\n",
"# NOTE(review): is_reusable presumably keeps sub-models loaded between calls — confirm in min_dalle docs.\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-04 20:06:14 +00:00
"Note: if you run out of GPU memory, reduce `grid_size`. A 4x4 grid has been tested to work on T4 and P100 GPUs with `intermediate_image_count = None`."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-04 20:30:39 +00:00
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 20:30:39 +00:00
"height": 563
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-04 20:30:39 +00:00
"outputId": "1c92021d-21d0-4fbf-b1cd-3f0f9a86fc42"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-04 20:30:39 +00:00
"<PIL.Image.Image image mode=RGB size=512x512 at 0x7FEC61BBD450>"
2022-07-04 14:21:12 +00:00
],
2022-07-04 20:30:39 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAEAAElEQVR4nGS9d7hlWXEdvqpq73PODS92ms6Tcw4ME4AhC42EJEQ0EkiykoVkkBVso2xJxkqWZaFsi2AJCxA5DgzDMHmAyakndfd0zv3SvfecvXdV/f54DZb9O3+8733fPTftfU9VrVVr1SH8t/twdIIQYRlk0AQCADgBDhFAYQEIoAJmuAEOA5yQOwQGMwwAISdEgzk8wAVmYIclOEEaaEFWkIABKyiEkoAOxFABCVwRAG3BAVRDgKIQAQGuQEZOEIMXVAFT0zQQwvKPXL7up67mw0/c8xu/9D+effBQAfrr1t947dWv/eGXn3Xelv2HDz+/Z+fBIyOTlbyw78pLrn3LTW88++zzlnD4uSOPP3v4sZXxsfVTZ1xwxjXz1eZaKoEJHFBCLBiMsNSgi1n+4oOf+L3f+5MTe04AQACma67j3Ia5fn9j09cLz4kXnLk+KkfG+Ojhb9756F3fHm17+69f+lM/9vVnFk7uXURp0SWUAjjGCRQhEVoAAxW4IWcoUFp4ggOZYIAVkCNPUARFQAoluKIkdAXZkVoEhwPqMELJkBYOaIATyMAEVZgDAhiMQARVsIMJqX7liw7/2Xs/P2nnso49IfaqolAw5ZhBWSdRKtVxRZvZPAcvWZs4SZrUSlP3vOuy1gQmt+JtsoVAgalhm0odWTD3wnTClIOupzjWLFo4xrTSnqx4XmSs3hFPB66ztcxjlxJp2KvWO9SpA6cYuNCKFgKHQMux4kDMUlfDimAxkEiVUoj1dGqbNjv1ey0lDwOu+4O56SoOH3/4hd/6xbueui/9yd/90fd+z1uTnjiRxg898MzHPvyp+778RVAKPdIWs2vWVGFa0tHvee2Fuw+Vr3/jiWpmrq6nivZ+4T/+5K+9+2fE0eXRvoPHn3vmsbz88D/87f86tnBk65lbTiz3qVm+9+v7ZsLkg3+Asy+i/gBl3IyWSFXalLucSnFVChKUck1ICnGpBSxasjPI2HNGXSEIHAFdUM8FFIKRGUpACWDV4gHFATUJ7AJLzuwBTmoKLxIACYGq1MFQQjQtBRYRIvOYDLnAgaLoCCvL+Jvqd28/eFXN2gGRjbyolQCBQdVjECdLhh4aJ1UhNrDbQsmRguTOyQNTAWfnkqwKCcxqwh4TNDgATSXDKuGoeUIgj0IFzkDJqBhlgoURqKBXgRWV9eLKxvHOn9zmr3/FNqe849vP/O+vPfvJzz4K4Ae/94IfuXn7wjj+3T98RSo+lLF0An/xXy7dvG7Tpdf+0Mml0HUsw/Inv/6Rv/m7WxQKAECM4e/++lfe9vY373zyOZLRU1/+1M/++8+k9fNXXnvxmRee1R567DOfenBpyTZftPWd/+ZHdz25/7P/6x9Gyxp7+I33XPsjP/QTmy58WWjWuU8euOuLt9/+1dtvf/T2u55unV//Cz+/+XXvvmV/2bc0TuNlJFBJnguMwRWSgggoEENO0AIHkbobnFAE7iCCK2nnuSABuQMyoiA7TNAVVICuvmAP1oEErCBHyTBHm6AJvQhzqEAdVQEcJvAamiERrLCCzFBCyQDBgZQRCUKgXkAXUXWoHUoIAV0GHCCAYasrSAgCYbgADnOYQSLqCplgGQawIwgiwwAhkIAIGSCGMcCIFRCQM0qBOoRQBxRHp5AAW31HAhwcUUVUfbAjF5QMCFgAQQhQQy2YrtHA2ab70xvMj3zx9vf+xt8/s/cYgPWXX/oTP/1Tl29e+9gjX731c/+464XDh44eW17M6sAgfvJv7/6H07/wqtdc/+Z3XDK3fnp9f/vc1IazN14+W68LqODkYCUziJkZLyy1uz/09x//0qd23Pa1O83BARLhBbbQuXXHD64cx17ux3x4sHVQtsyfliY03Zt62Y1ryfJdn/rb2WZx9tJXnRz00GYQkAg5oyYwgRTmEEYhgMA1KsbEkAkOBAAEFWQHC5
oAMEqBGnKGALEBMdoISzAADmIUQYdTe8eAFkDgBWCE+J1M4LDwnaUOmOnEM5XFwLmeik7kHgAliQYJkYJGi9zQyJ0DqENlxVVL1fN+tMQpFNdcqRusjeyRTcjgzoFVg2pVdFHIBxUlMSICQtsZJAzr2ilBcja2gsjchJi51IHYLKGEWMxzyVWh0quawDOJEhtziFHWixGLCTKrNTSdWqml1xv0qZ5KoQ7NJg7THiKsueb8G37ll7d95Qv7rr74qi2nrd17YPT+3/pvX/70J0aTpTgzNZw5vZSVudOmL7juBi8Hd9518IUd+8648uJHdm5uBrNzc+t2PX34U//0lZdedvqrX/oqdZqdmdt01vm5W/eDPzH1kfe/PycbxkFSuuTyzd+6Y/cnP1fee7aroZSJGBOIAVcToiDOYlCQIbg3cbXCcSG4uRY0NSpBUbi6MJFRIHZzzajYXdTBTrZaDlWRRKEOgbu7Ggk4CJgBOFSjiLm4KgxBnIUAgEECEAhIBaRYmDBClB5XufRjbUZmxGBVr93NSIkb5n4VisMIppqLVlU1N+inLnqZGAjOUSIPzRIZkToHjlRMXJKDg0fpVb1+nsSu6+CguoGw54TVi+q0PgFSs2OstJzq/qUbN73ppjXdyp4//fMvffHzT0+fsfZPf+0Vt9727Ke/+NQLD+96/ZuvnFmzfuXY0tYNa3YtHtqyfsN5Z74KeVNDiNL+4wc/9E//68sKw3eOnMtv/Or7N03zxWede2Iyvf3F133/q+//n7ceuufu+48f3HX+Wduuuv7c++5+fv8Te7/+6Vte/dpXvej6y+6654m83H38fz+ycebzPzg/N73+ymzTm8648OwjK3sOseriXfcd+voHP/C9U5vWnPHKdoAFR2YrOYIcJChK/UgCM4M6Yo8Doxt7YTjBwMPGSwbIS/YEhB7PVcidjZahBDEgoMowBzNc0EQ4wIKSoQVUY9CHJrQjqMEBiXBHSTCDCJoaykCAFUwKmDHdR0ooBnc0FQhwQ8gB5IgBcBDBFL0K5nACEUzYa6fibgCBCAAKQSIYUEUIcAICoHCHRwR8Z+UJ0WEARdB3cEPkU6GcAHUIo2I4QQgGuMFx6kwyuCMwqghzFAUDTYUIOEEcZiw4g8f0/B3v+cMP7N2/BGDLFZe96e3ff/LBT/27/3zvgX0TAFOzGM5iw5lTHqaOrdiyLezY8fSOHU9/5cG1P/D9F5532SWXXHbpgKZEhSXB2C2Yk8eS+eQdj37s19/1T4/ftR8zmLmAY43QSI+8UlsZ22SFBhKLhqN7x7sfW/jI0QduesXp2zesn+0N52e3vObGbuXYoWMf+9v+OE1f+31Lkw7uiIAISNAlFIUQyFExpIYZClA3qGuYwR3mKAILUIMVGCMKzJEYOQIGMBpGKaeqe1eUBtqHOWAw/04CMFiGMwQgh/upVEqOFsa9VquxgbhhNOCAJgSYa1eZuUnF/THaCZ2ExaJcwONSCD2YjCy7RXCA1F1npuQUlKAlVVgxBHLAYzYI9ceOrjhZApTjoMH6xbJMtTZh2rRmjkxcEMA1wlRnSX1SvK25gqvCOicwh7iGGepq1jpyT6ZVm4lmc3FsnB1slaYpHGPcCjkDMtdwVdUxq153zZkvf9H83PyWE8t7f+Ynf/zOr94LYG7jWpmam5wYhzGfvm6464Ensx/txcHRE4sXR6xfN7Vxy8Xw8PzjO5/ddfgP/upz1ez8BadvC9RbP7/u4HjN+ddv/v6TCx/9mw+dtrntSR2H6+bWHPrknSuveG3zokuSh1iCq5pRQYCZE0itdAmtITAkuEswd2IDEcNJkOEUGeTwBCYrxZQS4KpGCl2N7VAgaKmjqBKMIcFRHFzAxEZu7h0gZgZyY2QqDIWDGcmQC5LKSvbOLLuFqpqUSU3SlcJEIVbRQcLJS4M6knRoi7qSA86MJtaNUNFS9YNrXygKswHJVJqGjIhh4H6Jya1vvUKajMEIw7oeVKWoulAMxa
UOTJ6zOkTJc1XEvLpypvt31/X2Hb7tPb/41Ud2dWvX03vf/dobLzhrZm76od3HHjow2fn33zr/km3HltOaX
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-04 20:30:39 +00:00
"CPU times: user 34.3 s, sys: 676 ms, total: 35 s\n",
"Wall time: 34.7 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
2022-07-04 20:06:14 +00:00
"intermediate_image_count = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"None\"] {type:\"raw\"}\n",
2022-07-04 20:30:39 +00:00
"display_size = 512 #@param {type:\"integer\"}\n",
2022-07-04 20:06:14 +00:00
"\n",
"# Seed display slot 1 with a black placeholder of the chosen size;\n",
"# handle_intermediate_image later overwrites it in place via update_display(..., display_id=1).\n",
"image_shape = (display_size, display_size, 3)\n",
"zero_image = Image.fromarray(numpy.zeros(image_shape, dtype=numpy.uint8))\n",
"display(zero_image, display_id=1)\n",
"\n",
"def handle_intermediate_image(row_index: int, image: Image.Image):\n",
"    \"\"\"Callback for model.generate_image: show a partial grid at display_size.\n",
"\n",
"    row_index is passed by the caller but is not used here.\n",
"    \"\"\"\n",
"    update_display(image.resize((display_size, display_size)), display_id=1)\n",
2022-06-28 15:05:59 +00:00
"\n",
2022-07-04 20:06:14 +00:00
"image = model.generate_image(\n",
"    text,\n",
"    seed,\n",
"    grid_size,\n",
"    intermediate_image_count,\n",
"    handle_intermediate_image\n",
")\n",
"\n",
"# With intermediate_image_count = None the callback may never fire, leaving the\n",
"# black placeholder on screen; always push the finished grid to the same slot.\n",
"update_display(image.resize((display_size, display_size)), display_id=1)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}