2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-07 12:21:20 +00:00
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 21:22:20 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
2022-06-30 14:02:08 +00:00
"cellView": "code",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-05 03:17:31 +00:00
},
"id": "ix_xt4X1_6F4",
2022-07-05 21:22:20 +00:00
"outputId": "10a5df30-a55e-4180-dee9-844cf0aa1b69"
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
2022-07-05 21:22:20 +00:00
"name": "stdout",
2022-07-07 12:21:20 +00:00
"output_type": "stream",
2022-07-04 20:30:39 +00:00
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 21:22:20 +00:00
"Collecting min-dalle==0.2.28\n",
" Using cached min-dalle-0.2.28.tar.gz (11 kB)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.28) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.28) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.28) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.28) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.2.28) (2.23.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.28) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.28) (2022.6.15)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.28) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.2.28) (2.10)\n",
2022-07-05 09:52:30 +00:00
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
2022-07-05 21:22:20 +00:00
" Created wheel for min-dalle: filename=min_dalle-0.2.28-py3-none-any.whl size=11276 sha256=2950cee9d9859f67a8d19fb66ff3f8a6d48e8aa99527f854f6070dc9b104e07f\n",
" Stored in directory: /root/.cache/pip/wheels/c6/93/46/59099a6db4c0c4e962a4b02ea54705a5c4603ae365923f42b9\n",
2022-07-05 09:52:30 +00:00
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
2022-07-05 21:22:20 +00:00
"Successfully installed min-dalle-0.2.28\n",
"Tue Jul 5 21:11:08 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 21:22:20 +00:00
"| N/A 35C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-07 12:21:20 +00:00
"! nvidia-smi\n",
"! pip install min-dalle"
2022-06-28 00:58:17 +00:00
]
},
2022-06-29 13:43:42 +00:00
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
2022-06-29 13:43:42 +00:00
"source": [
2022-07-01 00:55:42 +00:00
"### Load Model"
2022-06-30 15:25:24 +00:00
]
2022-06-29 13:43:42 +00:00
},
{
"cell_type": "code",
2022-07-05 21:22:20 +00:00
"execution_count": 2,
2022-06-29 13:43:42 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 21:22:20 +00:00
"outputId": "61c8b4f8-8a3d-4a97-f881-3b7e309e696f"
2022-06-29 13:43:42 +00:00
},
"outputs": [
{
2022-07-05 21:22:20 +00:00
"name": "stdout",
2022-07-07 12:21:20 +00:00
"output_type": "stream",
2022-06-29 13:43:42 +00:00
"text": [
2022-07-02 13:04:13 +00:00
"initializing MinDalle\n",
2022-07-05 21:22:20 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-05 21:22:20 +00:00
"downloading encoder params\n",
2022-07-02 13:04:13 +00:00
"initializing DalleBartEncoder\n",
2022-07-05 21:22:20 +00:00
"downloading decoder params\n",
2022-07-02 13:04:13 +00:00
"initializing DalleBartDecoder\n",
2022-07-05 21:22:20 +00:00
"downloading detokenizer params\n",
2022-06-30 19:17:35 +00:00
"initializing VQGanDetokenizer\n"
2022-06-29 13:43:42 +00:00
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-05 15:53:27 +00:00
"from IPython.display import display, update_display\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
2022-07-07 14:28:07 +00:00
"import torch\n",
2022-07-01 23:44:24 +00:00
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
2022-07-07 14:28:07 +00:00
"model = MinDalle(\n",
" dtype=torch.float32,\n",
" is_mega=True, \n",
" is_reusable=True\n",
")"
2022-06-29 13:43:42 +00:00
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-05 01:42:27 +00:00
"\n",
2022-07-05 13:37:52 +00:00
"- `grid_size` Number of images along each side of the image grid (e.g. 3 gives a 3x3 grid). 3x3 works best when displaying intermediate outputs. 4x4 has been tested to work on T4 and P100 GPUs with intermediate outputs off.\n",
2022-07-05 03:01:13 +00:00
"\n",
2022-07-05 13:37:52 +00:00
"- `intermediate_outputs` Whether to show intermediate outputs while the images are being generated. Adds a small delay and increases memory usage.\n",
2022-07-05 03:01:13 +00:00
"\n",
2022-07-05 21:22:20 +00:00
"- `supercondition_factor` Higher values result in better agreement with the text but a narrower variety of generated images.\n",
"\n",
"- `top_k` Each image token is sampled from the top $k$ most probable tokens."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 21:22:20 +00:00
"execution_count": 7,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 21:22:20 +00:00
"outputId": "7045cbde-7ebe-4305-b7a8-8a76d7c68251"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
"data": {
2022-07-07 12:21:20 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nIz9adRk2Vnfif6eZw/nnBjeKYeaJ5VKKgkkBAJZQlgIAwab0RjLzJeml2nfxm6Mr/u2G9Meub4X29deLEPbmIY2xoAxNhgsjJnEJKHJmocqlWqeMiuzMvMdIuIMe3juh4hEYnnoez7kWhEZ74l3xbtj7//+T1v48X/O0z2jI1aisjnFC1KplSqIUpUiNJ6aqYYaDhIUB4rPOIcEckYVK7uXmWEFE2JDqainFqqgnuGM1lEyBq6lCChOyAVx2IgTTMgeQAuiSAQDAXCVsZAFTVjGRdSDI4EYTkk9DryQRrqWMeEqjIinZKhYQhyN0sP+gjGTC/46tn7TwTOfnx+TZ5/du/VlB+3y9gdfZ+rz0Jw71y5cSDwWZ3F9ZXV6fPnZx55/4pPPXVtf+fyvu/dVrz///Lue+tpvvQK89G/f2n7Gt3/0g2etH26Rssfxy/bLV73p4p3xxUceXT39scd+9nuffAJeDd/z53iy8Ld+nP/i9boL/KOfePCxG6dDXX/uq++2Mzm65c59u/snfuRH/9I/+BcP/cJvvunrflxbrgwQ+Jw/wjd/80u/7o/d+ehDex9/90d+/33Pv+yBxRe86Qvu++LvqTfKDbHV5rGSry8vdFdfODlLub/R/5t/+fMnzz6+FD76EU5vMBr7M6pyvIJb4IXIt/w57v0Mjo7IpywapsK6pxYoVKMIqgiYUSvqKRWDJJjhKiiqVKUKBmRqQQVXsUwSxIPDGyLgKRkVLJGNXMGQigkhYBVxWKYqzjGtaJScEaU41FMrzgBywQk5gQNFCk6YHFVBUcF7CrgGZyTP6o7Xf+mNayvzm3RS8/5Mh7N1trK/iCUV37RTdZataxuv0qc+eu+k9DlnwiLMfU2N5Cl4qdHV3Da+z2UtNQ8yU3cQpyLuLGtQ71ycanbiEV2vT88fdjmlPhenM0ycj+Lduu/3YhvqgON4Ki0xWvUuF+cbP8/jEIPPaoPAWA+iZ9r4tlml6oILzo/j6DWE0Fw7u3Zxry2lH1BnvlMPZRwn8+oKrSviZMStci5TOmqjFsuiPobibJxKLkl8MNXeypDJpe4tlrUUdWZY3eS9rvG+HETyuFrMZG/WRZtCG7KzaVO70hzMY0016CJIda22YSPDOCabTs6kbV1jx6tpPL4Ux7O2GRfLfcn6zCMnTz7x/IU7L9z9wKvPH/mpOefy2dXnzy5dWa/Wq824vuulB3sXjs7ffih+1sQw4k+ur6Zwwbk8rhO2mtappNPoXRqnktdSfTdTrPZ5zGd9aGZoTeThbIquVS9ZQ6kVk5KT+lAk9GdnB3utWe6l5o0Ec4Ij1KSuZmld9UEyrKfiRTs1R0l4UyGbOGdWSzFHqHloZjJlcq6BxrkwIanmNPb781bKWFzpRwt0NWcfKM5hwVmKjavY8ThZpZXSKhWXVLQCmCGiVC2p75Z+nGpJFv3Mq5tMCpbTOJ9FLeOmjMPk9roFeRIpFruS1TvwZlZXq17cwqIUkQv7R3+v/1w++CS1w3sibFbUDAkpOI8KuTIW5pFpIHjIjAkXwWOGJIJHW9KAggh5pMCQ6cAmmgYJSGCo4Ggdw4rWUTNTJs4ZJrpAzfgOrbtFZJaRzPDBdjm8NP3S4kg++rCsNi/hvs/i3JvghKklrxnOmCIUBmXhmDKdwwvTSKqYMsxQJQmx4aSna9BIKqwTXcR7+jV7HVZYbyiFRcuU6Ac6pSTMER1mmFAncmWTcBNxQh1FSEoRJBKVzRnRMZ1BJSvVQ8VndPv5jDgoAxj9hAh1g1OCglAEMzLElvU1DjvWZ6iSHCMAreEcOKRQM1QUViPRk3tChxcMMhQDj/NsbrA3Y7MiOkZl8ii0gEcjdSSNRCVXpkKAcQOe6BChClYogkaGY2aOWihGiuAowiwwjTiPDaQRcfQF5wlGEeKcOlHzbhRtRprA8T
VQlg1AVXLGFNcyHLPfsdmAUBv6Ap4WDIIjD7QNacKEseAKVnAtUagFlFzIGaBf4wNdQ6pURzUMvMcLZWLK5AHXUis+MmvxM9JFz2aJQjvSzllvaI9oIPWUCfMIFCU4aiEHnMMlyLgW35EMJkRIBYWScOA8zYze0ISbKIIFzDCIjiERW6SiCop4nMMgVWgQiAVNTJH5giyUHnFkdn/yCKnSLOg8/YCMIKDodsmLjANNxAml0gZCRKCM1AZRYiFnguA9kjkfWRWAcw0+0kwXwvO39uGBz33wnvvPzffm773ykXZzGNJR9+zmXLMsy81Ff3rc5Rev5Jl4c3r9xF5+x96defnBh69sgUv3cXvFW44+uro2rvRSaqXY3eeH11xt7lp+1e988K3/59948gSAP/tlTM/yt37lU4jHwwPw0M2H77nKm//kw3/j22bdS8ve5927uXajCw+MJ8/q0wW+6bd/9h+9CIzEW5jg/R/E33j0Leeu3Gvn+7Oz33pyfOcHxqd/9Tf/7hf/f26/8+IjT31oofc0zd39OpfVR46fe+Fd7/uo5GdPTwke8TSBpTCNVAHgBZgHmhdpjJnBHI2s1rgZs0oRpkxU1DFNmCKGbOc+mEVSwiYqVDADhxm6RSRgCXO0HRYoFUuoZ6o4oRS8EgLzjpMBqfiECdWhRhVcIGf8DCbUYYL3VMUEUcaCb3ETVfFznKITm0RocZ4KVRgFr4hSlOg4ePaZfNuwuX7IqO3s0njmuvau+eKojL3bHDNFjcsYTTnNQ7JyG/MXuLqO1ki4yxa3eB5z105FFmadKwdrW8VytS23z4/esNn7eHnhzPVO5p1DSFUqXlcpzxbNciorfI5y5NvG5FTt2Wk413YXk9s0+UYeLcajdu9Wi08Ml7NXzVMT5VyxZ1256upthwe3a3epz8dT32rbKJZTG2wWw5V+vVg2bZa108nVC2E2LyHGdNXVsdB43/jsa03eSozn9vZuc+HKeCKIiJoR22bhnRe3qpMTYi3ZmDdt9HKWp7G6Zi+mkkJw13RaLPduTMcxHGqQnJOYLWba9nLrhez2l9bm9tLjJ8+dXr68ytdPNscvcLw6ybnLehaWtY4HYfN5L7+jWnt8bfXis5d/6Zffc/3q8/csbr/z7qPQ8uzzN5659NjqLDEA7OO7vTv2X3J0/8sfWCzj7Q8ePvjGi+cWyzS7U9u4WavMPe5wtV5rIjYLFRNXhs2qZmbLhSrraco57c/3W5VjGbSCEV2FkFT7cVrM21nmTGVVx4Owf4FlLuWKW2N+oSG63PbjtZim6GbtbLFOvU49U6tdUHNKX4tE0SrZ6lhymrSZ+daoU8qVBPOm3e+bTUin6RQNh+7CTN3VeFxwYbLgNA7pJORxJktd3DrGYzsb3BhqG0HUBhI4RgkhpjxW02bp20reTNnpILII8XBor/i+92nZLM/5uVr3Qj6tSHDqa40rBqlNbHw369NZH2bXz3pubGgPcYZGTjd0S+aevqeMqDJlgjAPjCNhQfQMG9qI72g7UmWaECgZH1GjVPycWcvFyIvXiA0oKAnalug5PSM2eGWKzBztnEYpI0GwwDQQHF2LOEjsH9/H2/70fbddvfL40X3nLu3ft653r/TZ02Fv1TpKh2sIlVxYzjDFDQRBCl1DhRqZO0Jk9Jhnbw/vSIVYaRza0ffsz9BCyrTCLDJtaPa4cCfrFf0GyQAmeCN5OuWgZXNC7ikQPVGYCj6yWtHtA0TDNoQ53WK36KowJmKgGj6iwvyIbKRT8hpzmCIFjFnDes18yVAJc2wgzpkvqAlJVAPBV6wiE9m461Yuv8isI60pAVNsQj2+Zdww3yMJswPyGX7G7ICcYQOC83hoPHXEd1zc59Jl5gdMAxLB4YxSmLdMifkSEbTiKn4LX4xcCXM04oQOVsLFW8gwnNJ4JogtAjJQPXfeyrWrHF0kjxiIx1VwxBn9yPIAMzpPzZTI4YJhJGWajlxoIqHijdpwtM/pNTBKAU
MdChI5f471MfMl04B6wpZnmZh1DBOpsDzCCZseGnKmCVRwHfWqRzKlxzu0ZzGRB0whIRU1UGS7864IWKYaY
2022-06-28 00:58:17 +00:00
"text/plain": [
2022-07-05 21:22:20 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7EFDA43F0710>"
2022-07-07 12:21:20 +00:00
]
2022-06-28 00:58:17 +00:00
},
2022-07-07 12:21:20 +00:00
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
2022-07-05 21:22:20 +00:00
"name": "stdout",
2022-07-07 12:21:20 +00:00
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 21:22:20 +00:00
"CPU times: user 34.1 s, sys: 448 ms, total: 34.6 s\n",
"Wall time: 34.4 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-05 13:37:52 +00:00
"intermediate_outputs = True #@param {type:\"boolean\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-05 19:07:32 +00:00
"supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
2022-07-05 21:22:20 +00:00
"top_k = 64 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\", \"128\", \"256\", \"512\", \"1024\"] {type:\"raw\"}\n",
2022-07-05 13:37:52 +00:00
"log2_mid_count = 3 if intermediate_outputs else 0\n",
2022-07-04 20:06:14 +00:00
"\n",
2022-07-05 00:06:28 +00:00
"image_stream = model.generate_image_stream(\n",
2022-07-05 01:25:38 +00:00
" text=text,\n",
2022-07-05 13:37:52 +00:00
" seed=-1,\n",
2022-07-05 01:25:38 +00:00
" grid_size=grid_size,\n",
2022-07-05 13:37:52 +00:00
" log2_mid_count=log2_mid_count,\n",
2022-07-05 21:22:20 +00:00
" log2_k=int(log2(top_k)),\n",
2022-07-05 01:25:38 +00:00
" log2_supercondition_factor=log2(supercondition_factor)\n",
2022-07-04 21:44:30 +00:00
")\n",
"\n",
2022-07-05 02:52:00 +00:00
"is_first = True\n",
2022-07-05 00:06:28 +00:00
"for image in image_stream:\n",
2022-07-05 02:52:00 +00:00
" display_image = display if is_first else update_display\n",
" display_image(image, display_id=1)\n",
" is_first = False"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-07-07 12:21:20 +00:00
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-07 12:21:20 +00:00
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-07 12:21:20 +00:00
}