min-dalle-test/min_dalle.ipynb

226 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-05 00:56:45 +00:00
"outputId": "3b1436de-3af0-4121-8009-3672f5a03547",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 00:56:45 +00:00
"Collecting min-dalle\n",
" Downloading min-dalle-0.2.24.tar.gz (11 kB)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
2022-07-05 00:56:45 +00:00
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
2022-07-05 00:56:45 +00:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
2022-07-05 00:56:45 +00:00
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.2.24-py3-none-any.whl size=11198 sha256=d7c26dc9230a9b6b0e1e6d8026ec24192c248b4f3bcd61ad6fc3253449f45959\n",
" Stored in directory: /root/.cache/pip/wheels/25/95/af/94274bdc5e07b7aab46294f2bf36768aab35e4ed41371460c4\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.2.24\n",
"Tue Jul 5 00:48:11 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 00:56:45 +00:00
"| N/A 35C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"! pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 00:56:45 +00:00
"outputId": "495ab426-c385-4ea3-c4dd-63c053f684e7"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-05 00:56:45 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-05 00:56:45 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-05 00:56:45 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-05 00:56:45 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"# On the recorded run this constructor downloaded the tokenizer, encoder, decoder and detokenizer\n",
"# params and initialized each component (see this cell's saved output).\n",
"# NOTE(review): is_mega presumably selects the larger Mega weights and is_reusable presumably keeps\n",
"# the loaded sub-models around for later calls - confirm against the min_dalle documentation.\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
"Note: reduce `grid_size` if you run out of GPU memory. A 4x4 grid (`grid_size = 4`) has been tested to work on T4 and P100 GPUs with `intermediate_image_count = 1`."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 00:56:45 +00:00
"outputId": "8bc87737-27cf-4e6f-bac3-c87826314038"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-05 00:56:45 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7F3C2E493FD0>"
2022-07-04 14:21:12 +00:00
],
2022-07-05 00:56:45 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nJT9abQsS3YWCH57MDP3GM45d3pDvpfzICmVQqnUgFKgRgOSkBjEqKYKWE0DRVc1q6ppqK5CDV1Ar24VUNV0F8UoFUJTNVDMEkIpJCQlkkgNlcpM5ZzK4b0cXuYb7nTOiXB3M9t79484cXVzEL3wdddZcSMszMzNPXx//n3f3k7ANwOPAQPQgAKcAwQo0HG1ZaADBDhAgAAdaEACGAhgARxgoAPl+BUA8/ErDgyAA/04kAAMXAL5+F8CKuBAAgyg41iH95fjWAUgIAMBzEAAAPw4SQYSMB8b7IEE0PGjfmxpxwkLEEAHKsCAAQMAQI/jHrYGrI97ugX2QAYYuDi2D0AAA+zYPwAGFHDAgOW4DuXYP479H7o99F+BDbAHBFDgHBgAASqQgA4EkAEH8NA6+HHEBMjxBYDpoQUcjm0KMAMK8HFHDovDAIAOKBDHPVKgHQ/cgz2iY/+HefrxqOE4PXvo+Kbj/A8TOJwh/XiOMdCOf3GchgAJcGA57k4/fj2OpwcDBCyAHo/mYSnseOqm4ynkx0O8AtrxhFkBF4d9f91f2O46t8vexcdM9+6dX07zZpPBOodbFHUl8Ml2W1sTZaa+208iSSJH91yiEomObZ5z4WZ9asbO6yzk87w05BQED+4eRJIoT9O9R66v+jR3FgrhTmZMGbVZKWmeLpZpDvBKV2xNknceUip9aTnxYjbDMUVGFK5O1llJcwRaDyUCUa37zUaWZdlXrMuKjBGWlBoFkYD9Yj+N47g7n7DaoDvgEEEQIAgBFLUhFfiMMaHPCIY78gghJAEzEAjAG0xAjg7MC5hRBAbAkBJqRzcIY94jOlqloURrUMb5HmUAMwioCyBoDEoYRywTOKAJiXG5RxBUsTi2J/AZAJSgBXPFMKAaVOENDJChGaAQhjCmGUNGAqJjNhBjyHDHvMCBnMABb6gNeYRTKtzhMTVEIBpEEAYiTEbsop5LmWsLp7CAMYaMaS+rtS0XYAcSJkZ3rADOJCWWGdaRBTUQQcIxXaID6wFweFydmJGwXGLNWDqEEQpnkIIcrUMFNoE6KsMyOKAOSeCCZUIStAURaI5xwHwBIgiDACS0CkkwRp+xUkx7hCEKPFAKBkVkyI2bt7b3z016mannPNT9XS6psLc+szKB0SAgFprqvE4l+hLigUGstLoMAxYgSfE6K6Nbt4CQqDn61IlEReAgciM4LHK3y9WA3pySUCQyjeDOFr0WUfSpUzPPiVbhcxJrrEyJbMnCS1sqTJDYFLZD8iBVBrFEQwSz6n45P13laHMTYh18YUIj7aQqDkQNZA6vqJAhU1Y0VU4U0ZZCbZX66RASvt7Q6aoUFkeoxKYkZkuDMndhpairra/Xm098fP/+p54dS3v5k+ubjw2PX1+B19uB0kgffar/P7/jh37d6/Aff/OXr+z5kZ+0QukGlEylm21/6T2XP/ued/43f+HPFMGy1LJ+8T6X++1TX/byL/gzf+qH/vL/+3/Cv3f7uj/6jd/5HX/mPf/4n79w++M+f/Txl8TXft0fW73mP/uMZn/8d+B3/zY8ee1xyi9/rtTtY0++5kW/4U9/2/d910+989/f/4PtO/6rb/2830hv/Npvmu/undrjN1/xwkf6E5//atT5HT/3tuHGyWtfd609+/b82J8HsP2C33z2opefV+19l4fTVVrmvigBCO8XHSIUtS4llQhihTsCDHg4QGIivSMHCEtzJ1YO8vAW1mqikPAL1c5AEIcJXCKiUzcLJiFbDDMiwwVCQCVIGAMEcngHE5AQBsmIMwVeBSX0CfAjrElXsTMVuMBmYHUEEO0YljbHyDcACizHuLJ7KMKlY9hejqFrBPrxbwdOATrGpEM4r8fwmY8BVT897u6PmOlB/+0YRFcAAAdOgAp0YAsoUIEVQMeWh/BZjo0PPehxApfAeOzwMI
3lGI8DOAUWYAO04/wPuOcQbu2IsQ44I46A4AHCUGB3fEeOIzZgAAJg4BSYj/07cO247CNgQDmCADoisH58044TXo67fFgfPR5TPLRcD/o/OUJAPAQQD//0uD4rwIF6XKLp2OdhDg9g32GXH55SB06O72yPnT/AKA+WdP3peGg54qEH869HzDc+hAUPqHF1hK16XH87rvkDvE4AgA2wHNHz4Zx8DKhAfCpdp76cDctdsmf201LWt67dfFTTVOt9CUYZSC3o/sVuVDmx8WPz8xU1gx9dbU5jeDZuN0csjTlSbRdkc5az8WzY8R2HpRioCLGJwy04LZMNZdWrOxdToEpKosLn1oR06MOFLK00le0j44027e/gopFY7YU5L3JXrUmsbt54MjYfu3im2kWyJERCgCaLWGrTnFpDA2PQPK5T15Loou6SqYVN0cr6hKjj1hY64u6MLpQosYoQc9p14OSUHBEFxWGKDpS00eGyzjQMEU4aYR0tYXuKcMwVZLg/gRgpoAnkiC7XT9FqbFe+21Nvq3Gocwt0fvSkN/e6IAhMQGAkIIMKyKECIvSOcaRhDDPcGLEEIkGdskQH0oqz+NABwhSc1KVCFOMmKbdpwuaRIrTszrFU6MClOAeWGWmgxCui3bwHST49yaLk3nptc0CHXDRqa1RhCUHl+rBMl0TdOkNySkJOgVSnHfJqXfNOBvO76FiVW5skz6fzcMTcWGRFeukzxjFvr9cX7mDYIhrACIYSWkVJmA1ZAYImsCEIkjEvIAer+tBtAu3BZ+NwHa1NaUI4WocwOoEYGrh2A3dnUIFPIAUxuoEUDkRFYjhjGNArygkowyu6Yyiwe/fi5VLuEk1KubZL2q6vr1e4vAvWfdhKINDucVmXUeS0qwnu8EVv+xfFekz5hbg/qvZ5ymrbyufca5KzYV0u7Rwg3XOEujDRQjUPpdZGwiqeQmdywn6j691SzXsW3TidozfZheOGnBVO53w/2Kl1UqwbVfaenaG3+uZebZPczcjC2WOnook25+1ytUqFPLJ0re7xyLhFTBN21WIltJDBY6VplzA3X2tmHZrviTFKYoCEWGMU367SarNKzEZNmgdlIh/gJ7oaUmg2zXRjI2cvlcTPfeDpy+eeu3E6yrVHX3bj2pplq8Ve8rp7f/0vv/G//+tv+Wc/+Itf9hp8w5fw2YtfevcSdaeILGHr9WNPvffnpW9OB1vobO68ybvHNp9/8Wz6ye/+vv+/uOQn/u6/fttvevKa7k78zuVUuT7y9vd9+LObnd3CrZt48mZ87NJvJnrJ9W/8uZ/54K+Ffl4BfHYX/9e/8i/+y/lVX/t1X5lX13f3Zm9bpQuhGxfTxfPPP/dVv+4rgRd/19/+AwBw/UWPvO5VRPfPdlioBF2MsWqYNlou6qWAxzTs+nRtu5VKCeku9soYPU02r1Nq4OfIboyrzQVuwyPt1TG23Dk6zUm1LBGUgxrCG4Ui1pIvp7qkBqITS5V5Rvegta4H1/u09JjBiZzjEJnDAAIRCkDnCloBE8oKkrBUBIEcaQVz8AEhbCAMSeiBviAKSEGBbhBFVvSCJlcBnkYIgxh2DNUB8CncwQIELI5B9wHlE4CCEuLAlLSH6AQ9RrLlGKTLQ9EuHopw/hAbIcB4xHMGjMewl4/A6EF0P4CGcuxfgLNj/+0YXMeHJmPH4JqPaIOOyOYQ9dtD/IQfiZByJMnkiJniIVJtOKLAw3Db42z5CBzzEQrYcacerMDqyJG0I2eTj6PH8V85LhE/BA7y8VuHk+KASA7ztIcOEB2/9QBnbB/q/2GEgeOb+fiCjnBWH5rGgxHp+FeO/w7zXx8H9eP8Nw+NEg8tVz4icjrukR/PigfngxzfebCwB6JrONKEciG23HtuEVrEz4YUQWnaD3o6ReTECHdYs77ZSsA+de+uCJ0Na2u2+P5Tce4lViTmaC2er611H0io33vs9NF+nzGspuZ7XxBgQXA8cnOcLu716K1TNRqZLL
inVlIo6Ln9HYSMOrLLFC1lGpASiWaeW931qRkGzbldsOpZSm29RfC+LQ3R+qIi187SNLdw3+ZSA+h1cbu3z
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 00:56:45 +00:00
"CPU times: user 34.2 s, sys: 524 ms, total: 34.8 s\n",
"Wall time: 34.7 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_stream = model.generate_image_stream(\n",
2022-07-04 20:06:14 +00:00
" text,\n",
" seed,\n",
" grid_size,\n",
" log2(intermediate_image_count)\n",
")\n",
"\n",
2022-07-05 00:56:45 +00:00
"image_shape = [256 * grid_size] * 2 + [3]\n",
2022-07-05 00:08:31 +00:00
"zero_image = numpy.zeros(image_shape, dtype=numpy.uint8)\n",
"display(Image.fromarray(zero_image), display_id=1)\n",
"\n",
"for image in image_stream:\n",
" update_display(image, display_id=1)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}