min-dalle-test/min_dalle.ipynb

216 lines
654 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
"execution_count": 3,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
"outputId": "3aeb9c3a-09f2-40d7-f9c4-80b7d1ce0712",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: min-dalle in /usr/local/lib/python3.7/dist-packages (0.2.24)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
"Tue Jul 5 00:03:52 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 37C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"# Pin the min-dalle version this notebook was run with (see the output above);\n",
"# the positional generate_image_stream call below may not match other releases.\n",
"%pip install min-dalle==0.2.24\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
"outputId": "1ad581be-0649-431f-ebda-8fdac89e17e9"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import display, update_display\n",
"import numpy\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(\n",
"    is_mega=True,\n",
"    is_reusable=True\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
"Note: if you run out of GPU memory, reduce `grid_size`. A 4x4 grid has been tested to work on T4 and P100 GPUs with `intermediate_image_count = 1`."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 00:08:31 +00:00
"execution_count": 6,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 20:30:39 +00:00
"height": 563
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 00:08:31 +00:00
"outputId": "8939a5b5-2bfc-4267-b4f0-49d55f43f64f"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-05 00:08:31 +00:00
"<PIL.Image.Image image mode=RGB size=512x512 at 0x7F56D8D23F90>"
2022-07-04 14:21:12 +00:00
],
2022-07-05 00:08:31 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAEAAElEQVR4nGz9Z7Qk6XUdiO5jvojIzGvr1r1lurpMd1V7b+C9J0HRgJQoSjQSzTxJT+KMnrxZouYtaUYciZIoLRJaI1F0ogEJgARBAgQIDzRMA2i0rzblq8vX9TdNxPedc96PLJB8WvP9zpUuIo7ZZ+996Lafeb9tdxNGWFEOZSNmDiaww4WlswlzPSt18YmBE2DhY8/qlaADOUgpGMRdFCUmlAxSqivkEiCmIHgUNxByrQxiYTVHhCdGxLhAGqrBmVkl3MIn5DNUKzpniCgCABcvPY2WjFH3wIWLkVQUhNj0XGduOHdCXUiP0iRygBIRwytlJqkdINPERq1J7rhsEiWemWmWZrg/CFkUn6XtQV5daHIPNNkZOnTXyh4ZzDv6kUElSjfc2dpICVohT7rBYFHrOfdUvLi4WRmP89W1dn203bJ3vd3XqV4veZzHMd6uKcQtt7knSpBwT5qsTIAI5gqJyQuHObHBCZ3lmqhjk6hr1y65R6rILWyI6IVQFBMSENyLI1upFR1CkRjcUYgzYjJxE+4p5Qg3Z0eEB4kG0/rKsWfjTmy0EIZ3gCMcQQgCM4hBDnNQDQACQECOMBjgDhhUASAAYlABOxzgCg6AQYFSYAxzFAISKKCKyQQtI09QAyRoBUZAC3FsD4FAyuiGGGY0gfkeRluoZ297u7/9S/9x/5Pf/NzzuJwwt+/YX/qBNz392cce+8aL6ZYjMw2tnd54z3e+5tZjR576kz956rkXDx3Ds+fx8mn096ef/5n7vvHpp/fesjIjr//n//qj/+Jn33h461t//z9cuzSCAG96y9z3/4XXP/vYxV//2HPDSQEAgAlHbsZ3/fi7b7/zLa++69HPf+JPFpe2/tW/+K3YtfW3fur7j+278+xm3dv3wL/523/v5KXTP/Eff+rO/bd/8wvf/PpXvnT/fbuf++qlU988WwE7gPT787/1hTXai/EYY0MU5BbjMWZ6oMDqZn9mIr/177e/8OmVH/nJ1YPfb21BJESN4iBGIUiCOTwQHSjQZRQgAmC0EzAhFOEIBwUIMIcBEWCHBxxgnv4muCMIqYIQhtXhN+5KM7l0kcPEy3zK4zz2wo1UJFoKRu1kUjCjjUYxCQqC28isShqTrutGvX6lzO6eLaoabsGQRFV2I4BRMrqIuh8WaJ20IcqRJ4EmSDBRBYuEe5QgMiIz98QKEAiNSudtl4M8lLwW5DAyeOl6AvHJgngdo3G0qadeJmKTWqt6blEKmuiayeZsXG1UtLccPDfcGm6sXVw9//zpCyfuuPtNr33dq3YvzFY1gYTTgNPCuPi1q1euXb1y9vy5MyfOdaPReHtztL2dx2PvckpKItmcVapef3F3v6fan+1rr9594ODegwdvPnpsbunQZpvOX994/OT5b1zYWSspqqoKpSB3IkldjpLDDebsFm7kIZw9iCOCIijMmTXEvIsuQ+LGpTdFDMEAlJyjCFAQGT5CVBRVUIugG0+lZUSN0kIdLCgBMMyRCNlgBE2oErynOy31qIiABH1VMxCBiMNDAgRLIpVKzSSoLNxLgVtF3DTsJsWDKIiD3HtEzCgeGjSomCKxZ0eEuwpTAlyFCSBhJjIm5vDiqFn6qs7kEe4B91q1V1Vh4WGQACABFSJEEkmiNSUhCRC5W5Ra00xVhYl7UVIhUYQwEzwFEotWaTC2YCLiMU1QK9c6O0iVzFRc1661A0RV8AxJQ9yTJvXrEuNue2vz0nh7si3hCTLTb3bvXqmbOtXJwsbbZbSxM2zH8CjRuo08hpVJv4wHPO5hbU+1a1Wbq96VGSHkMikuDAMRNDTnzOScGCI9rUruzJw4guDuTJh+4UpTFepSyCkC5lElmUn9XNriFoEQF2YWQYQSaq2AiDAulI2EpamVA7kEMyKcVEBwWHQO7mGGgEAQKGAdHAgCMYIQjARIhXAQgI
AHgtHvEXl0EwAAgRhCcIACxIDCCQGEgwPVAADagmCEgxmLu5EEOzswgwjqCkzIHQpQL4MJVKMRsKMGmgoBkOOVD9186pmbD+PyCVg0b3rLox/+1JNf/swLN92yLAQrcn6U//Mvf3JlZddiz0HUrcegBwC0KYoFH2fZXnnszFqL8S3zg1/69dGl0TTU47Of23rp2c//3R+448e+69j7P3g8AAAeOHkOx599Zf/utZ2t9sQLLz3ymjfNLUubBq9+9JGy3fz+L/7S7luu/4O//+P/7t/8h4//3B8c/l9/Ynl2+fDeo02zi2i1IkTcSCQLFY1i0gpFUoRCE0VI4pI79Htzvd7c0ZXtL2JpuLXZ76xidILCqBq4IRFIEAniaAMREIETiqEYqAYcpCDAp7nZwYRG4QYrAODf/j2sAKAVELBAbzySssSsPXDBIGmjiTpwcO7csgHe2YSD65SEUACz0uZWgnu1SlORs3uEGxHXcDNTFlURcvWgoOyGXPpVb1Bp9nALFLMoTUq7en3zsJyDAkyaYBYErVR6qWYQPDzMzPsqdVLPhbwr3ViU66bfeGGEoG1cZ5JSLUTVIBbUUi6sFImiz9FEk1IT0BQ0SPXWxPLG+rve9u63f+9fG4SPVs+vrp3c3Nko3gZVnfNoJD1dOHbT8uG9s+PRZPXq2tbG9mQy6kZbk0k73Bl3bW4nZeP6+rVLa70KTULq4fLLL5ye759cvunuh1515J77D2qVF7k2nBmzGVorw4646XU5WglUqRQU82FXLDRCuKfdpCUH4LAJINXsTLh04+we4RBRZs1DRslgCQ2kjCCUQFvAjKoCBGYwoBhU0KuRA6VDABVAgRIAIQmaGkRwQzVSozCWiFBic0pVDYeAIOHm00IVJcZcQC6ABIuKAw5ikUYSERGIBOYWgQROCMtO6hUrABIB3GGkLCSEYBZSdS/hJFQL6zT4JJBDKFWi7PCUVDkRQOAA3DMDCKKgAoOHCDxMgxsOU9KqmYVrqKEgkzKUZRApwrLnhnnEw00tJuhr6mlKRo3nhluMJzoeS7ddyhb6Fc3MbY6GFy9ffPxr33r8SydKJ6vbVxzoK997+8G3v/21r37T61K1kLvu1JmTX//6k26hpNeHw+H2Ti2j+bn+TL9qx0OVqpmbn1tavPPmwzsye23sHWsJhIhbwI3DNLFyDQJRpKRVogDcwj06LwnJPAq8RbFSGG5uwpU6JqVNLBWThJgnc7MgQ5Fgs+IRCA+EkDDIit2IvuYUFBIUVAtqACIIRQTc4QZJkGk5TxTMUEMHd1AgHACYoQryYEKvAoJEgggIGGHaDhADQBiCEUow4qpqIjy7wYS8MhKRukrIRhQAMYNrJkBTViESZw4VSgQl9qhrXn52/fqpbvUidu2Z23304Q984A/OXJjcf8fR1c2rx79xDSAs7OGeXL105SrQ62PPLjCYyD1waXPh6k69p7vyzPELAJ7+2mOPn55mLxiwsFQfOty+eOqJgzfVvQqjDn96vvLFs7fccn6Qvnrh4tkH8OKVzTzaas9dblOkE6fOXD3/u//L2//Ov/8nP/kP/s0v/86vfeTt3/HO+ZnZZD5T0YO34alTQAYDuyJfn4wjNEDk1rIJ+0LXdfAtKgvU7L3pgdO3fOXCEy8uPTC5yLOoArUiA5HAAgMCsEClIIIZzKFAEZRpuQcEkAhGYAERQDe6BAYYsIBPUzVAQGKEIxOp1jWDeSXIuPUAqpocHDYuXe6ckYhiezwWCUZ4iVpFkkTXQVkYdWIOdObuzAKtqigeVoIiACKaqWqVKJ4lguBBqCUJvOt2KommTgLJlnMUFR00NSIoEG7ZCtwkolIHPCXiUEUt6DJ1PcaCMJL2tb+v0bnZXqN1Q2U42bm2Nlq9PtoebYZsDZZqYo2iWivXVvn6u9716ne/729eX9v+xlc+8bnPf/HilXO1EpiL8+zMzE033XTH7cf27t43mJ3pz6SNzT0vn98cTgBqvXSWy3g8Gg
3bnZ3h+vrw8rnrl8+vDy97f2a8MDM+e3r10oUzR48/ccf9Dxxc3lvt4upae3kcw6pfCRlry6iqKMFduBpIK
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 00:08:31 +00:00
"CPU times: user 31.1 s, sys: 125 ms, total: 31.3 s\n",
"Wall time: 31 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"grid_size = 3 #@param {type:\"integer\"}\n",
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
"display_size = 512 #@param {type:\"integer\"}\n",
"\n",
"# The fourth argument to generate_image_stream is log2 of the intermediate\n",
"# image count. math.log2 returns a float, so convert it to an exact int and\n",
"# reject counts that are not powers of two instead of passing a fraction.\n",
"log2_mid_count = int(log2(intermediate_image_count))\n",
"assert 2 ** log2_mid_count == intermediate_image_count, \"intermediate_image_count must be a power of 2\"\n",
"\n",
"image_stream = model.generate_image_stream(\n",
"    text,\n",
"    seed,\n",
"    grid_size,\n",
"    log2_mid_count\n",
")\n",
"\n",
"# Show a black placeholder first, then overwrite that same output area\n",
"# (display_id=1) with each image the stream yields.\n",
"image_shape = (display_size, display_size, 3)\n",
"zero_image = numpy.zeros(image_shape, dtype=numpy.uint8)\n",
"display(Image.fromarray(zero_image), display_id=1)\n",
"\n",
"for image in image_stream:\n",
"    image = image.resize((display_size, display_size))\n",
"    update_display(image, display_id=1)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}