min-dalle-test/min_dalle.ipynb

232 lines
656 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-04 22:40:32 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-04 22:40:32 +00:00
"outputId": "ff5b8a25-6426-4dc1-c66b-e0b6817275ef",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-04 22:40:32 +00:00
"Collecting min-dalle\n",
" Downloading min-dalle-0.2.23.tar.gz (11 kB)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
2022-07-04 20:30:39 +00:00
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
2022-07-04 22:40:32 +00:00
" Created wheel for min-dalle: filename=min_dalle-0.2.23-py3-none-any.whl size=11271 sha256=51fff31edf2bb31ae51f08052489ba646f8f56df13021ce56c8b9b5c263850a0\n",
" Stored in directory: /root/.cache/pip/wheels/92/13/66/ddc26f0e09a54d6541a0fc5e6c7b5896aac0c36a0d9a34cec2\n",
2022-07-04 20:30:39 +00:00
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
2022-07-04 22:40:32 +00:00
"Successfully installed min-dalle-0.2.23\n",
"Mon Jul 4 21:51:04 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-04 22:40:32 +00:00
"| N/A 35C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"! pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-04 22:40:32 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-04 22:40:32 +00:00
"outputId": "6b2f90ff-0f97-400c-a3c8-d2b600897028"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-04 22:40:32 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-04 22:40:32 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-04 22:40:32 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-04 22:40:32 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
"Note: reduce the grid size if you run out of GPU memory. 4x4 has been tested to work on T4 and P100 (with intermediate_image_count = 1)"
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-04 22:40:32 +00:00
"execution_count": 6,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 20:30:39 +00:00
"height": 563
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-04 22:40:32 +00:00
"outputId": "ac57d1c4-b6b5-48c1-d047-5db3279c1298"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-04 22:40:32 +00:00
"<PIL.Image.Image image mode=RGB size=512x512 at 0x7FDD4A28CFD0>"
2022-07-04 14:21:12 +00:00
],
2022-07-04 22:40:32 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAEAAElEQVR4nFT9Z9xt11XmiT5jjDnnWju8+eSgc6SjLEuWbVmWHOWEsTEGjDEGA64Cim4KqNBddHVVd3UlqitXNVDkAopQZBtsnADLCVuyZUmWLCvr6OjkfN64915rzjnGuB/eU33vXR/377f2Cnuv8cz5/J8xF+EPPoVzM4CgGaGibMEBSkCAG8hRJuAWoYX14ARXUIUzEmNzAyRIEa4ghlW0jKoSB15h1SJLtQ03HjZzVnvEJhBXqx1heTBc0I69a9vWnURE3EfJlMpiaBdZiEBOQmZJpqBWbWCWwcvBlxKJC9hcrHWjmg/vPLA0WlmYG6Vqly5uLC7tSphubJxluRyH/Zf++qkc256ln+UnHnxw49LGLE8vvnSsYZ5YGA3Gu3Ze+8u/+Bs3X3fkiaee+JG/8d5HH39hNGwO33zzwUO7z754bHNzeterX/26199DgaXx9dnm+pWZIQ9DqMrr027Xrh3t8vLCzmEzmGviiNiu2z1/zc5dIc5tzHh92m1MNlZXz06mV1avnNnaXBsEs6IVat5vdVkshrbt+5yzhtAwVdV+WuAVKfZbvXMZk3NlVQN771S3PIwd7H3nwiaGkjO7OoUt9uiaSu+dawRq3eqVG4yAqTrcqGqpNRsnFp/tve3RwVvrWl9CgOUmop9OhUWKcBCSYKql1mEcLRA268xcEtXsdZp9FFKebkojHKUUVgO4kNbq0KmFwWC8MKhMM+2bcKn+8X/JR59lxvwcXv/Ob337+9533S4+smNft7n2C3/wK8Pd8Z1ve6+XZjLJXb/xtYee+o0/f0D277vrdW96Yq2Z7b9pMt7vC3NYHmHBYVNsbWBjA/0EswmUoIyuQg1a0Rk8wSP6gr4DCVKCO0jghhBABEqoPVgQGnBe2bPxP7f3P/eHn/7aC3Prt7x2rbnWi2GwgDgCE6YdJGFhAVEaKnlr1d3AATFAFR6wModph/U1zA0wGICAqjTrDy4MD9PWG3ec+Ovf+9Rfv8jtt/2Uz+8sprUDG9ush2YG23QKHWOSEQS5gxCowgTVwRUEWA407541OoCgWmEcW9vaQs1oA+AAoTQIE0SHRdSECgQFGLMN8DAOFmu37l0FB1RFcDDQG976CjRLKAR2sAIVJQMCCAIgEV1BZFADrSgZwoADjkqoBXH70AEI0Apx1BmaCG2AgpwRApgAR61whSrYoA4RqIIIpaJN6ApSghEEMIAIISIXEMENTogCJsBgBmFwQGR4AQgEqEEdMESDKhAhAgKqwQ1BoA4WCAFACCBDZdQZBAgMZxhQAakgAweYQAF2iEEMmSARKGCBG4RQHerIPdoAJqjCBJ4RGEZAhAOmiIJ+Ck7gCK1Xdw8EB9yxNkUkNAFqAMMNyVEMvcAiSBEAzRhETBmmMAMzyECGCmz06KeYi6gABO4IGXBUQd4+T4ZlJIYLCjDpESJqRT9FjBADhgEzBioagQUMBNMKMhhDFcZwBTNiRBRwhDlqRS1IIwwS2ECOaoAgCYrBHTGBBQ2oCVTNK4FDHA1AoZiXUtVKaIZpbhA9kqdiRdiNMBdjpFlDoUmtkpDbSKSHT6hOiOaGo6xZRczzxPtlNBwYRIPGRwMeDRZTaPoO4Ha01ECChNHi4g6mOClrN99827TS+Y2NC9Mzh3btvCJyakOuue76aFgtdfXS+oMPfeVv/q0f/ePf++28Nbt8cXVhIS0tLUyuXNkct29+630vu+Nl+/bsfPHFpx559Onnjz575syVvssZRYLB83SijWqvsmPXeM++RR4vrPkSsl+3Mn/XnQduufX65T2HSxhLFIQ2zS8sj0mnG7kPyKXkKjEFTxU1xiGzu5jWUo1D06hkc5ZBHKSRZTfpWJndM/EojZdncUM3mUuTBzV0oeW2xnWdwQdkMTIb+lSYEwtj5HMGrpU0G1Uvw5qrR+FVayY5DEZViII0EcYybEJg41Kqq/elgkzESIiZWHXa97OioEhDGTWjYq6mkpxBOau7UxPioE0hBTISGKVmtIsGIQNL83jV7XTqySeeuf7m9/zYexbF1jh+97vetrTzmqZdzN7Nza8Vi1oPPbR1rBveds+dNzLiZx48gU9/FnffhbfdhTYhF3hGo0CGGBSYdNSQFwVVNAEwFEVISAxVwEAEIoQEZhhgCgYCIQGpWX1pwu3Ff/43X/VbL+z68zP7q7WTrqJpzYA+AwXjOTQBlvuNTaAiJjQJpqCA1MIdKWJ5ByJDAOsg8IX2dJf3zPdtv0EXL2Azze1qfNysXdwAKWJgUy8GBZoGiEgMECYFbgBACnE4QA6wszsI7KjmZDDyGDEeogDicAeBBuy6fb1AS9Dt0mwIDDjY0GxXQIYJ3FAzUJAZ8wlDoCqaBOthAgfUUR1kGEcEQmphBR1BK6rDgPkWxuh7GGMAiKECHEEKNQwbSMQmQIApQJgfoJtCCSAkAgAzmIMEJBgKHBiPAIMWFEcKIAcTSkVK2L58daQEAVKLJOgJ5nAHMwaCyJhuwQnNAG4oGW0DAawiJIjADAK4ARFLQ3SMXAAHMSQiMPIMRa8ezh3uAEM7zI3AgDoAKADCeIBAqAGlQB0hIkX0BnMIgwzVEQA3xIAU4A4nkMMUaqCE+QEagVdUA4CQYBV9h7bBIMAMFuCAVUynaOfgAAhqqBVOWBhiHDFjqEMNInBDEagiMYaC6nACN5jNQMAgIkZUgBKGAbotJzkgV7BABQAmU8QW5KgEcVSDBRADBu+h9eo/Iw4gQM6IAaY0iHByMwSEEN1dSAyomj0K81ACT70AynCEMByMYpNmVAfEC6EJJPMWyBzsTu1CaEJV45ip91wTYSQkLbNARHYz76mDTcOm9OsNLXoY0HDI80NNXJVDKfl8nV66vHV5be3S+sb6lbULWxPvKm/1ZXGUdo4HO3ctDAe+1M3Pejt9/NLmS6cnmzk4feWzn//wH33kNa95+bRujUcjrzZaCG9402vuvP2OJx77+s/9x//8+KPHzR3/v1sA5lJsZaqW1+rm+tq5l9b27hos7DqwVfOXv3Hmy39Z5uebG2858KrX3HHHa24bLVzTluHMqItIoU4ka4ik6PpZ0X4QLUotOmMJwRN5yLxFkNbIpZcBDx0iTbUKbbxqHklLw2Q1yVzPsZBKxIjG6MdSDUwTKkgY0WBSei1WBSZViCASGkkapdZoAFFfjcyKQKESZdbloO5AVgya4az0G32/xQhe2FwNo7lFRe5qjYFSoGBUVNW8OkITa7EQal/L9sNegGZul+26HnhydQPjeOB19wx/47d/bxeH+162d+fyeLnfoFOPbUz5XJmeOj/58vHyl1+7/Orl2f/1ga0az5585KkbP/xHS3mzTBZeeOD2jWtvwT2vwY6dkBk0szo6ZXOQVytwQWzQG4gRWrCjn8IKKEAY2qOG0AzVOh82qBk2xcxat68/tf6Dtx+6rek/cX61Wd6VpWoIljNqRTPEdAt9BzLkGahgfkhZBzqJEiX7qouzYDiWJtBsq/aFInmFprg8zC898YUzF1fNdmLae6PGBLbommLsa2+uFgcxzVmdabZmcd5qhkHMLFgutbW2D1M1I2MoSQigzGFo4tImFLha9KZyRx4qtxzgLIDByDOBCXEI4qId3EGMUqERRDACRVBAiNAMCSgFYASCE4QgCiO4oFbkHlQRADCkBRuMYIw0RBCAoA6qcKAXhHh1UjI/hBrI4I6ikIhWwAIzFAMUHFB6KIEBCWBAFQ4MA6KgEVhFJBDDHewQQYyAQYCi4O2ZQYAZ4HBCbDCKEEYuaIYgQt+jKiLAgBBMEQQSUTNSQgoAQQhVUR0xYjRCUUCuKtm2PjngDA4AwAHM0IoqIEZs0EZYvap/gQGABJFQMxCACC2w7SpKUAdHsGPSQwB3SAQZ1KEOSVDAFeogAwJ6QAS5AAY4zMARQpj0iIIgV49oBhWoARFWkSsAmGJaQARVVMArnFCBkDCK8IreAyiDGEKoFaaoBlcogR3mCPHq1EYERFCgIYgjAaXAXYiSe2E21ODaJDZOriWCPboYh
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-04 22:40:32 +00:00
"CPU times: user 31.2 s, sys: 159 ms, total: 31.3 s\n",
"Wall time: 31.1 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
2022-07-04 20:30:39 +00:00
"display_size = 512 #@param {type:\"integer\"}\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_shape = (display_size, display_size, 3)\n",
"zero_image = Image.fromarray(numpy.zeros(image_shape, dtype=numpy.uint8))\n",
"display(zero_image, display_id=1)\n",
"\n",
"def handle_intermediate_image(row_index: int, image: Image.Image):\n",
2022-07-04 22:40:32 +00:00
" if row_index + 1 == 16: return\n",
2022-07-04 20:06:14 +00:00
" image = image.resize((display_size, display_size))\n",
" update_display(image, display_id=1)\n",
2022-06-28 15:05:59 +00:00
"\n",
2022-07-04 20:06:14 +00:00
"image = model.generate_image(\n",
" text,\n",
" seed,\n",
" grid_size,\n",
2022-07-04 21:27:02 +00:00
" log2(intermediate_image_count),\n",
2022-07-04 20:06:14 +00:00
" handle_intermediate_image\n",
")\n",
"\n",
2022-07-04 22:40:32 +00:00
"handle_intermediate_image(-1, image)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}