min-dalle-test/min_dalle.ipynb

218 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 01:42:27 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-05 01:42:27 +00:00
"outputId": "704f8de0-05f7-466f-d00f-4cb903f7bf3c",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: min-dalle in /usr/local/lib/python3.7/dist-packages (0.2.26)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
2022-07-05 01:25:38 +00:00
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
2022-07-05 00:56:45 +00:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
"Tue Jul 5 01:37:07 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 01:25:38 +00:00
"| N/A 36C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
"# Pinned to the min-dalle version shown in this cell's saved output (0.2.26) so re-runs\n",
"# get the same generate_image_stream API that the cells below call.\n",
"%pip install min-dalle==0.2.26\n",
"! nvidia-smi"
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 01:42:27 +00:00
"outputId": "ac8ebb27-6f0b-4f86-a9dc-e37af9fc6f69"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"from math import log2\n",
"\n",
"import numpy\n",
"from IPython.display import display, update_display\n",
"from PIL import Image\n",
"\n",
"from min_dalle import MinDalle\n",
"\n",
"# NOTE(review): with is_reusable=True every sub-model is initialized here (see the log below),\n",
"# presumably so repeated generations reuse them — confirm against the min-dalle docs.\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-05 01:42:27 +00:00
"Note: reduce `grid_size` if you run out of GPU memory. A 4x4 grid has been tested to work on T4 and P100 GPUs (with `intermediate_image_count = 1`).\n",
"\n",
"A lower `supercondition_factor` produces a wider variety of images but weaker agreement with the text."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 01:42:27 +00:00
"execution_count": 6,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 01:42:27 +00:00
"outputId": "e2c0d589-2b6b-4f6d-8feb-5635cd7ede03"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-05 01:42:27 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7FA1746345D0>"
2022-07-04 14:21:12 +00:00
],
2022-07-05 01:42:27 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nEz9S5NuW5YlBo0x5lx7f+7n3hsR+apUVRaZRT1UJatUIQnDZAYyQA0hM0k/gH/Ab6FHddQoE6YGfTr8AIwGHczADAkBqneRle+IuHHvOcf922vOOWhsD4E3TuOYn+Of78dac40n/4N/+L/5+Q+TF96yKPTzTdJjxVUbmd2IYUiU3q/9kou4nsRMrjnYrTUVKSzNjvDu+TpNRhQPXxPhFG1QsGds5fv7l59+s6r6suUViDGdc40Tmb6KfZnZiTbjQixjJSaAp+fNBXPVivnqsBW2YWE4Y1PX8+3bT0dfzwpgMjuDcuwLwvAU0HXBe+NUZl+ixysijGkb8BTAqP12LnRPE/DSCKCjm5IY7p42UZCAdAvTyEjZbWvcsMfZ++vrQ9NuMpjhBDCa7jlS7F3oyzj40IzwhI7BEieAtt/RVUgfJ56ObqYINwEBHLDrehzCVMnAQoc9sXq0AGEa9DW+jHAcLvppBWlPAyIIm8qrr0Um3RgpezsohEGRoly7Gdq2gUWJHjIjIPQMQJhUVl+PM3quNoMrICLMHhI1GbR7iKvLjVMhNzPHbXDsrkFrKWfe6ZYw7WmgYRNK99shot+37Vma8Hh0OY7ePGR3VXsIghgHYKGH44KMmZVZACNe47t/53/yn33/8/U6fvnZ49q9YteAJAkBCs20uxcZBIgRNQZgcoy+uKKVTTKU3Q6npZGeb2/fvB6o7cXdOJS9d6aL4YklpzjEs8dwDA+OKaz0jBhjkHp7m1AfR8NWLg/DMYgOXO/P1/NwXRN+Fs48pzbVVhJHTK0VV/XT3VaMTlYcKoLD+2aB0c7n88s3r9HVzgBTEzNAzraTwalIXZjnxUWFt1KNkxzCZJRnTDq63l9e8qpWhM3FxBiBQh8S+lqixT3GUEa4dYIgGDNIyqaFQR8RnDIwnN2SmahIthdoYUJZ021iAvU8H7NroARCw/v/2e7FcG8RjX4vYpguBXalkvYQ6pkxa3Rd758+xfW8HAmnRj2cmGuQTNcFjoPvm0EeUxSeFXGopwV1T5n7YmIzLoiG3CSiHQ0+398ex/K+Sr6aqZxuYLfWTCQnxKtmuweOCc0lESkYRPSA1PN9Mgq8QJJSLHFNY4LeFUkaX+r5cp7v789jSYGqUGjaRkVo9mZOrqOev1yll+Mff/7+PUpPNYmpC+R9xaQYQwZNiLtrZbprAmNpwh6pR4QjMQJ65t1tM6HEHtIUbEBje0DGrus8VNUNiCuoMgCM5+OSCTamx76XJ0AiaMMemySDBmdaEgkMWR63PRIaUfdPlBaGY1Da9cwIdk3MNImcacUURAThUFxT7zWp0DRcDZBBAlB3w6Diup7nGbV3c8TgCCAEE7LgsTFwjQXcqyuCBj3MTFI9NmxQ4nT9d98FgFBKCg7sQZAK2QOIRDfPM0V4unskkcSge+4V3fcmO4gkMRGaPZSr58j0TGQ0TPC+bIjodlCLptrmvR8QmOmP9c6VAWMGWFwwbJLenECEOwTANQR0b8g9ihA8oDwz9ljTz/PUVCMIBKExjGlYlDzAgNxzbxhN3AuhMENq3NPAUGiqSBmUAgNAbXf3sYSui+OJYNgDziDdFG3OHqAHEAKGHudP8pf+WfsvP539levqfnn97ifHmufXc/kCj9BCvAPPXd/F+kmvX6hLbwfWb8fLCf5cXw29DEmdV/0YjkOLL795rC/1tfkuJwHSe0aRu+enLy+nJyKNfmDl5K55a3+j45uOa80TjcKnxzeftn6cLxVg8ZCOy0g+E+njr8zjs6b1hhZMcI8BxTX707k+bX+WqDod3+TLW9dX98F8SBFTrqf78c13P+PaX95KW0bMwD7kXbPWumoy4qRLrPAxeHBdPU+3jcMozLfyDn/m8ZP10l/fWyWMBhpR3mhKaDvXAThQmINM87n78rxQ31R+VQFj5nfH60
+v+B4/PsWjwXE2Nv2W9XKcv/1+vMNPva0xHeZ4hoxr9uvKl+EbCoJLP42XNH/UW0PncLNBibbwbbys92vSjcJoqCVW94r13vuV8RJRmFF345vHw0DTBAK6pl/OtWEAn44zisXNMIYzID2kHD39kufqubgmeMYjnK3uVvec6zC3NnbMML95eTneOaiOJ5oYNviUXvJcG086o3rcMY1Opiqebul86X6azjkmv7ter96fM3bNpyFg1FTgi+uhI6od/YZayOXodoqCt/j2enydN+ZvIP7s03fn7q9To4wVPMXuawXfa9p+Xec31Pv0xlb7gKB4d8H45uXA5MZORneFGG7Tz+v95cyDfS3sfk+fL+RaL7+a94zUWOqjUSveA4vnb3H9qp8dW64EJe+ecZ0PLUV7r1w9Fr1Ce/p51cvSC/pzds/14OMldPD1h3lHhLoZhSkG7HnE+ZtcP/az8Ix72RS7zeDeX7/9lEfOe9CwPC/nepav8UktYRYCu4aPl8dP8uX97a01M9cZAZg5vubIDLpWHjkAlEbzPGjrunYyToGRMe6FXcjIT1o9l7UJIDAbRwYd19QReVJbc8hvjF753XqZr+8T1dNnBJrM6e0jLFMrQmMT6bn6zEXH196nYoEVPozPNI/1k3ztL2+bpVVLmjFjfDmC7F4vWuw5ojHR/Xrk18tfPaf44DyjX4Wvgdb62fmYz+9fMVq10AuQ+OYJz+P1SK33fcmECJrwc1+z55uDj+wfMKj3b/L1dcW0fiwHkWhoZ/M49GPP0svPpF9cQxanJFF87jH4eGBl2OECI1I4wj/25Y0lPCI/13WqD83Lt4c4n98vJU71O2rXHD6O5KSC3nw8Pr1//+NCfzlyArz2PvKI0PS1EANKhGnw6joZB/SMMHYOPumRoy98NhkDyNEzAo0zH9/seHNAlywbEHrMYM8cYoCOFBGI2V5kuQ5mgi2L2N0hLeb0hDiYnuEAIYgvWv2+O0VNkBjDQxfFU1ENEB5HSBhTDVQ9D+owtmx3OM6IxHqbZ0SgDVlVNb2Cj1zrijeMMDTu5U0khOo6U2kgglAyVyQHG22QgAXaY5KjyCOyuw2PrbWWAmCxxwiGMaXIULXvA9aRiYHCNZW5FjgeSyANMeLUqt4jLkdSE7NRBySx3TMTK+T7TCvSk949r6Ejg8yaWuSK3NPJ2Jw4jtd4oK5hz1RK8AAeajwRECLDu01BNgZJbczBSEjCMf6KzcjXdfK9ygNNfEx0PXDIMeaKILyyOGnEaAaXOhVpNh3G5Vas13Xo6x55R8sgZMBwBM7IHo5GpiURALuawgoc5lMMOyMOLhvveIqU0S7PRCOUYs70degd7/y7//C/rL/4008vZx9aXTKWMIMvfZEgNONAiOjpHA3Q6aBOPwJ1CRRgVU911VipI+LTenx5v8irIRvgAAplZr8/d5JjbOKMM5WFsg3DjXG91V6MI86TaHnDMGpm7zIx8iH9lC8/XO/lNzsGFAwzlGTXvhY0M1ujOUIZ6p4hYwbu6d4rI5AnHV7vc5Fuu2vGBhkSuavsssjiwAotoE2TMMLuHnd3Zr6Ih4/32Yxp4OoWYALUCl/XFaSthmlRSsIug2wMJsRQLK9X48q5SNg9M9WidhjAT3Ds7i98mwne2xqxGJSrew2N6eRShFNw0yZnXDPbk5ACoVjQ1/3es2UOIYBGZNgN0g2i93RQj3WarhmSY48RUnUr81wh5q62Osht2ya0cpG9uwTbtnTEIWSzwia1e4BBT3u0IhUvWvvaF68p4D6XEKQe4Nve0xtGtUOildL2NT3x3py+0hgfcwxqa+hw2fCuq6cROhaPiWcPaMAHZKMxAQFGotb57/57/6vnz//820/nRcujEABlsDdmkDojpv0pc8qXdgPukdQGqSO4AnvvPY4I20CKiGB3JzQwRJILIc6QMoe02z0FmThS32o9e0a7wWmTNGhGpIOefYObH495EACM5qgxUsAQctEtAhjD7u4GBA
rip7W6u30NNCQ8hpIrElfthAZjBU1GigYMcUbidE/VZGZmYHh1H6keNwxDkiIivK+ZaVIIoZHrAEoECAzgA
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 01:42:27 +00:00
"CPU times: user 34.3 s, sys: 570 ms, total: 34.9 s\n",
"Wall time: 34.6 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
"%%time\n",
"\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"grid_size = 3 #@param {type:\"integer\"}\n",
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
"supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
"\n",
"# The model takes base-2 exponents for these settings. Convert them to exact integers and\n",
"# reject values that are not powers of two instead of passing a fractional exponent.\n",
"log2_mid_count = int(log2(intermediate_image_count))\n",
"log2_supercondition_factor = int(log2(supercondition_factor))\n",
"assert 2 ** log2_mid_count == intermediate_image_count, \"intermediate_image_count must be a power of 2\"\n",
"assert 2 ** log2_supercondition_factor == supercondition_factor, \"supercondition_factor must be a power of 2\"\n",
"\n",
"image_stream = model.generate_image_stream(\n",
"    text=text,\n",
"    seed=seed,\n",
"    grid_size=grid_size,\n",
"    log2_mid_count=log2_mid_count,\n",
"    log2_supercondition_factor=log2_supercondition_factor\n",
")\n",
"\n",
"# Show a black placeholder the size of the final grid (256 px per tile), then update that\n",
"# same display slot in place with each image the stream yields.\n",
"TILE_SIZE = 256\n",
"image_shape = [TILE_SIZE * grid_size] * 2 + [3]\n",
"zero_image = numpy.zeros(image_shape, dtype=numpy.uint8)\n",
"display(Image.fromarray(zero_image), display_id=1)\n",
"\n",
"for image in image_stream:\n",
"    update_display(image, display_id=1)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}