min-dalle-test/min_dalle.ipynb

220 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 01:42:27 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-05 02:52:00 +00:00
"outputId": "a1821c1c-b5b3-4f6f-a53d-abf9c716f868",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: min-dalle in /usr/local/lib/python3.7/dist-packages (0.2.26)\n",
2022-07-05 02:52:00 +00:00
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
2022-07-05 02:52:00 +00:00
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
2022-07-05 02:52:00 +00:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
2022-07-05 02:52:00 +00:00
"Tue Jul 5 02:47:56 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 02:52:00 +00:00
"| N/A 39C P0 29W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
"# Pinned to the min-dalle version these cells were last run with (0.2.26, see output);\n",
"# the generation cell relies on its keyword arguments (log2_mid_count, log2_supercondition_factor).\n",
"%pip install min-dalle==0.2.26\n",
"! nvidia-smi"
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 02:52:00 +00:00
"outputId": "101b43dc-196f-4d4e-bb47-3674732c0b25"
},
"outputs": [
{
2022-07-05 02:52:00 +00:00
"metadata": {
"tags": null
},
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-07-05 02:52:00 +00:00
"output_type": "stream",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"from PIL import Image\n",
"# display is injected by IPython at runtime; importing it explicitly makes the generation cell's dependency visible\n",
"from IPython.display import display, update_display\n",
"import numpy\n",
"from math import log2\n",
"from min_dalle import MinDalle\n",
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-05 01:42:27 +00:00
"Note: reduce the grid size if you run out of GPU memory. A 4x4 grid has been tested to work on T4 and P100 GPUs (with intermediate_image_count = 1).\n",
"\n",
"A lower supercondition factor results in a wider variety of images but weaker agreement with the text."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 02:52:00 +00:00
"execution_count": 5,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 02:52:00 +00:00
"outputId": "7829c121-02d0-4c52-ac7d-a077bbff9bcb"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-05 02:52:00 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7F742771C350>"
2022-07-04 14:21:12 +00:00
],
2022-07-05 02:52:00 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nFz9ecxtWXYfhv3WWnvvc+7wzW+s96rq1TwPPQ/sgWyySYrUSIqR5diKE8hKFCJOAMMwIgUZjAQGIiOAgQAJIxm2JZGy5JCUKHESyZ7Y1WNVd1d11zy+ef7m795zzt5rrfyxz9dt5/1RKHzD/e4999y9f/s3LcKDnwcCdAJfIkQMS5hDAC0AoTg4ogCNQxjZQQYrcAIIRFCBFkRDZJiDBOagABCoIDiyIzCMgYDimATAAAY7uh4cYAXMAJAV5EgCGExQDExYAqTgJUgAhwdkg0YEQ1BkhWV0S4QEBSxgEhEZBYiEIcMDGCAHRZiCHSkgAMrIABwaYT2aAhDUIYIMeEQ0sMIVqhh6IACMQpgEkMMIkbHsQAFUwAxjmIEUTYI4lFEAJhSGZcQMEJwQEgqAgOBggyvcoQojuMAD2voyGcGhCiWwgRjKcAUXpAR2KEMdLFCBL5EcxUEMilCGEhoHO8wAQh7gDDNwADlgQEAgOMMUZnAHJ1iGODgABIsgBhEAWIckAFAMiDABA4FhhiDjfVK/SAIGQADAgpyRAuBwgjG0gxiagDwAEeaQBGJ4jygoBQxYQCYQIwKRURSB4A4FyECCEoAOohCBOZzhEc6AAgXkyAphKFAcEsAKEhQgENRgBQJwRCmICXll/bf+q92LBT2jiYChFGhBYmhGCFCDAiwghylii6GDGSRAHYFRerDAZbxWTrAMd7jAMsTgBGY4oIAI3DF0SBElQxUxYVCQgxhmUAMRDGCGFuQCCdACYbijOITBAjcwoxsgEW7weqcZVMECKIpDFSIA4A4zZCAy3MGCXkEOEcABhhsIMEWKyBmg8bvuEIIpmOEEOAggggOuCDI+PVNA4AAAdQBwAxxEMAUALUgBxREExREIxRAYQsgFIoDBMV5AAszAVG8imI/3EgBVpIA8IAqcQISUoAomUIQriCD1OkeQggwuiAwroAQYmGAEzxAD1XsYMIIECDD0aBI0w+rLJxADhYXNOArlohCBOryAhEzhaCapaGEK5iZMRaFD38w4D+rkbUxWzIlBlEsJEti1z0ViDESmTlRExIwjU1Fz4uxmZsGZXIWJAjtAYDUnkDpKGdKEh6Fw/TaRG3ldfVnIofXWzSW1acglSHRnA7kWqEEZLcMyiCZ5ft+f/JIIKJN5DimVoQcjxMjiMHOWYi7gFEQECqaQGHBiM0/MJauIhwghcWGi6CQEc5CTUyAnLdkSQhYapGXEaACbKSFNJUZxObG55UtKwUJM2Hnv9jtfW9x8/RNP3fOpT37k3H33cZifWJtTK+9e3//RD16/eunNi2++9sKXQR2u4P//39mEX/+PzseFX7509aEHt1ZPzS6cP3Pm7KkO+eaNnUlMlz64vb55qpmtPPDIk8Th7IMP7e0d2jIZhyY1FsR14gyYSQJH3j/oDu/euv7arffev/7etRuql565794nn336wiPTex4+hSm6riyv7mW968NO6Q5XWim9GrJmOtgddu+Qh9S2OauIQII6rKhpUbLBAct3KedJoGa+ItRYiATuSk9E7WrZ207vvXStXd+4/6mPLFcftruHe/35f/yP//s//dL3by6260s+McVnPvXxBy48eOHC+tamTNOZpvFm3nDjMbZFchAsu8yUIiV2DhQUWcJKahzoD48OrD9SWjDtLfrdo+XlK1euXHrz0vs/wu995390bX/9f/Xk0x9Z37TpqcmwtR7u7u5/cOniO+/cfutl/PBNXF7gCDjJeORxPPbQ2rNPrD394D0nzsmp1dnWRjJ474QYjGcaZMjdYLl0R8ujQ80lZy2lzznDC5wJkODmPnR5UCFwSMnJu0G1Uw7CTMRcNJvTkLOEGYubdyga0qS4ujJAIfTZl9BMTO5iDJFofW6SZOah9835GcLTv4rA6HsEQR4AQxSYA4pkyIzsMII4isMBz3
BAGCIwgQhKBjLcwBFuMII7mADHJIEIVuACI1AA2bjmwhATyECOzAChKDhDFQoIwwE3kMAVKGCABd0AjigKAEZgRl6gcWQgTpEBcZhDGG6QgDbBHC7gAFP4ALdxK4KDGA7AQBkhggjLDtSgZLDDgMDIHdjhBA9AhGXQ8fOXiAgQoTBYoBmsI5phAWHc+5mAghjghH4YoZhUPMTQDCiIwe34/B2AjzArCcihARyhA7jAFM4QhhpAcAMDATADDBlAgjoCoI5G0A9gRqj7FmMgwMEDoBgIKQAMYZhACzhDBM5QAAwSkEHq5nSMnBBhABkMI2jQDDCyopmAGEVBABMkgBSFAIYEeAYpvAABrkDEUNAEkEECbAAMaQ6ncc+uCNgVxdFW3MnghNIjDAAgCbnABMOAKDBFiBDG8hASAR53zbptm8IczOAAAJwQBCj4v/wz9BmeEROWQOMIhFwAgxOKQwRU0bkEQckZ7pCEwNEp9wOIoA4uI0TIDk6IBM2wDAWCIDtCADmGcrxPO4hAjGwAgQF19AWJ4AAZBgcYQeAYcaoxJEJovDjKiIzBxmNGxR+JMBSowx0hgQw5wxwUR0hBjt4QBOQghxqIQQABAgyGIEC9aTFishRAGMGZMHIBEYhQfHwCEmAGrjchYAoCYDCDlfE29gIiqCEEOEZoVe+revBgB3x86wUwG4EgGIGRDV5xiSIK1BEDOKAUgBAYXv8oIya0DRZLuCESyKCOyFAHGI2AgTKwQ5isWBPFmPq+BJE2cWdK5sIhxMCGfigSCOoiCgBBlgUpNtOAoy5HUnKHExEq4FJ3uDepHvIQidypFC/ZYpAY6GAYyNA2KaRgWbUYHAwnNnd3oV5JQpolOuxyhMFBICcUA0AKE2YWy+rCHMDEXNSLGhNLkkUuppRS/Y4Ng4YgRT3nAhfPFJokQZloiIF5uPe3/8bELWkhkkGVhFOMbBi4j+zqPsR2wtNVFSXtwkAcyEFEhuBADERegjDglpg9UY5CMCnMTcHgrGBJHnouC5NE85mztXmB0BX31G5M19anayk2e5O8LElu/fDwtd/+tU889mu/+LMNx/UT1IR28e6Vf/HV7/+zV97Z2ZkcXLx48ZW8/B9szCeB5x8+9ciHzz/3oQcunD+Fo5CmR+ubSYftJrWAKnkZep5MVk8+zM2pnKWNIYbYNvPF4vakidIGNe61Kz1K35SiIbHzRBimw17x3Rs3L7715g9eefPq+9t7d3fX1/HMh059+hNPPfPoxnSiTNod+rDcSWmIOLI8mPfLI8pL9HYi0sTKoeqUySCHRfKgzs4wqA3i+wHCKIQQgkViYUVoPM2kGw55c/d2F9ut5sz9R7r2h7/77t//jf+qvurpNP7UZz/xic985N6zJ2ZtkhjdDyepJQlClq3Ly17z0vqOfGhmkzZIq2gkSUyT1ZXV9a1cGmJvGjTt0C9vdYc7iuWw2F/2Qzcc3Lx1/eat6+++duNb39bvfjBe6i88lf7qX/1LTz+PE06qR91hf7B3Y3e5e+fg6PKtvVfewjsXceMGSNEyzm3gkSdw7vHzjz65+ZGn7j+zydNUt7dWlQa1vl90eaFFNfeOwax361ACu2fDkHNxLooUk6sbrFcPlNgoY8isrh0xO1jiWilOOGJWUNPnQaihwpmPMjr2rk0TtujUOCVC8XBUzLKnZBPCx/4W0KErCAwt49EtRpQCAZwQAkoFJQQI8gD68ToVkQJygWbAURxMCGFc3eqhkCIYEAGH48WOYRiRhxrcIIIQURw+AAolEGA+HgQJ6DuQgABmiEAVZmCGCczgA0AjwnAFNyAH6u/KeJ6LEVnhGV5gdcs3ECMQSkZRsIzrMjPMUDKYUAAHgqPk8bRtxwwHaFyv1RAiYkIp8Aw2qIDr6ZpBDgKGY5ZLZMR2WoEUwQgB49m9XlUWgMcTLYBSEBJCRCnwAWQoPKKrGGGKUuAYL4gIRKAGVVDFlAFCKAMMSAFGMB6BJgQBMAcRgqBXaA
YRnMEEieMtkcv4kuuNwQJTZAUFMEEVVLfGgCBYZoSIJsIcpYAAEGJCEzEoyhKEccMD0LQYlmCCVtKoYiYeH
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 02:52:00 +00:00
"CPU times: user 34.2 s, sys: 593 ms, total: 34.8 s\n",
"Wall time: 35 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
"%%time\n",
"\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"grid_size = 3 #@param {type:\"integer\"}\n",
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
"supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
"\n",
"image_stream = model.generate_image_stream(\n",
"    text=text,\n",
"    seed=seed,\n",
"    grid_size=grid_size,\n",
"    log2_mid_count=log2(intermediate_image_count),\n",
"    log2_supercondition_factor=log2(supercondition_factor)\n",
")\n",
"\n",
"# The first streamed image creates display slot 1; every later image redraws that same slot.\n",
"for frame_index, image in enumerate(image_stream):\n",
"    show_frame = display if frame_index == 0 else update_display\n",
"    show_frame(image, display_id=1)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}