min-dalle-test/min_dalle.ipynb

241 lines
4.0 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-09 14:35:17 +00:00
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-09 14:30:12 +00:00
"execution_count": 2,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-05 03:17:31 +00:00
},
"id": "ix_xt4X1_6F4",
2022-07-09 14:30:12 +00:00
"outputId": "3381aedd-4f25-4b73-b0d3-9d3675f86f02"
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
2022-07-09 14:30:12 +00:00
"name": "stdout",
2022-07-09 14:35:17 +00:00
"output_type": "stream",
2022-07-04 20:30:39 +00:00
"text": [
2022-07-09 14:30:12 +00:00
"Sat Jul 9 14:20:18 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-09 14:30:12 +00:00
"| N/A 37C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
2022-07-09 14:30:12 +00:00
"+-----------------------------------------------------------------------------+\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting min-dalle==0.3.11\n",
" Downloading min-dalle-0.3.11.tar.gz (10 kB)\n",
"Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.11) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.11) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.11) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.11) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.11) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.11) (2022.6.15)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.11) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.11) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.11) (2.10)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.3.11-py3-none-any.whl size=10554 sha256=72d50f11369edf356faffae4575a19840b7cf2fa92319ff34f836d566efabf65\n",
" Stored in directory: /root/.cache/pip/wheels/08/7c/9e/e87a42b400d85af27f9f5fda5c834262f4f20d105d91f1ffc0\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
" Attempting uninstall: min-dalle\n",
" Found existing installation: min-dalle 0.3.10\n",
" Uninstalling min-dalle-0.3.10:\n",
" Successfully uninstalled min-dalle-0.3.10\n",
"Successfully installed min-dalle-0.3.11\n"
2022-07-04 20:30:39 +00:00
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-07 12:21:20 +00:00
"! nvidia-smi\n",
2022-07-09 18:44:13 +00:00
"! pip install min-dalle"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
2022-07-09 14:35:17 +00:00
"### Load Model\n",
2022-07-10 17:23:42 +00:00
"`float32` is faster than `float16` but uses more GPU memory. Change the `grid_size` to 3 or less if using `float32`."
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-09 14:30:12 +00:00
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-09 14:30:12 +00:00
"outputId": "d2f82d75-7178-4813-d4f1-0c2a09d9a550"
},
"outputs": [
{
2022-07-09 14:30:12 +00:00
"name": "stdout",
2022-07-09 14:35:17 +00:00
"output_type": "stream",
"text": [
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-10 17:23:42 +00:00
"dtype = \"float32\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n",
2022-07-05 15:53:27 +00:00
"from IPython.display import display, update_display\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
2022-07-07 14:28:07 +00:00
"import torch\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
2022-07-07 14:28:07 +00:00
"model = MinDalle(\n",
2022-07-10 17:23:42 +00:00
" dtype=getattr(torch, dtype),\n",
2022-07-07 14:28:07 +00:00
" is_mega=True, \n",
" is_reusable=True\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-05 01:42:27 +00:00
"\n",
2022-07-09 14:30:12 +00:00
"- `grid_size` Size of the image grid. Reduce this if you run out of GPU memory.\n",
2022-07-05 03:01:13 +00:00
"\n",
2022-07-05 13:37:52 +00:00
"- `intermediate_outputs` Whether to show intermediate output. Adds a small delay and increases memory usage.\n",
2022-07-05 03:01:13 +00:00
"\n",
2022-07-05 21:22:20 +00:00
"- `supercondition_factor` Higher values result in better agreement with the text but a narrower variety of generated images\n",
"\n",
"- `top_k` Each image token is sampled from the top $k$ most probable tokens"
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-09 14:30:12 +00:00
"execution_count": 5,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-09 14:30:12 +00:00
"height": 1000
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-09 14:30:12 +00:00
"outputId": "598d90d3-51c7-4f4d-e12c-019924a7817d"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
"data": {
2022-07-09 14:35:17 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQAAAAUACAIAAACXhmigAAEAAElEQVR4nGz9ebBtfXrXh32e5zettfbeZ7r3vu99px7VraZbSKKlAiwDYpAZHBxwwFhOKlRcTihTNsRQOOUQ47gyVpzhn7hiDCTYmLjClBgFCFBIBBlJDUggCamlVks9v/1Odzj3nL33Wus3Pflj71ct4aw6dereM6yzz96/8/s93+c7PDL97/6j47szq6Pu2UwcXpA7U6A3rCMBUazRcGNo64qDtTJuWFbSSFswMPARRqaRkHjxJZ4p/+X/CRq/eH34d/CrfxUPX2E78vQ5cUM9gMMq3YEjOFqEpwTHX/gxnv7Fb3wvL6E3/Ld/B195mxB59YKeWY64jhhOEMMLptRGTGhHDFXMaIZTRFhXRkevOOiBapgQDUADLQMI5I73SKOuNBgiZgDdqAZKz2hFAKN5REFQWBdCpBzQhkQWB4avSEACVrGOdoqBQKNnqhEUMcRRK00IgeXAoJQKQnGYYEqCVpFAPuA9KXJfwdCGGC7ROgKt4qAAHllpDe9AEKV1uuAc65HB00Cge1ykdfz7928L2pFINjBcQz0SsI5BL6jQBO9pC6aogiEO65jQFMsk6EpQmqMpImjHOh2s4AwLVId1QsMUCdBRT1lwSgMEyVTDBwRMEaEZGqhHNpFcSYGqSEQM7XRBBIWWMUUdqkh1RkNVfTcT56z2lEa8b04eXH7wnW9/nXcKzx1J2Q48fYoarlEWlpVu59c9OXpjN7F/gYEfODbGDa5yf+DymrKyrEyOpXB3z/WWZ8/JncsRL3Sld4phwvKC60QulErccqyIIzqOMynQMuuCGauwTUhjWZBIFFpGYC0cVi62rHuCx3tConRyoTXmA7tELzSHKk1pRoIY8IlqOONwZLOjF0ojBYLSGypIRxtBSYYaUhgdW2VK0PDKGEhXhAF/Te3IQFXsAQfDb7CZ1YgPOR6oic1E7awLo2M9Ms9cJialdJaMVyLMGTN8YO24xATrgWKIkjaIUMEqw47aORRCZbtlXhlH9oXtSIedcijMM1KRDsJ6T/TMe8bEvCcOvDiQAuuCeEqlV6JjrjhlySTPshCEVlkb7UhuDEIxVCmN2kgjVHKhNlYIDsvcHRkjPVMMKRwWxCMrU8B1YuDpATxdsI6HTWLZUxayoyrTgDfmzLijNlohdBq82DNN7G8pncuEgEVaoyrqOD7nMrCudMdeCCNr4SKixpyZIne3TBe89YxdRAsV5oVxg2VKZwCgGjnhPsGf+7/y2Se0iRBJ8OIWr9QVKyQHQumYEB3rynZLXuid6iAhRp/xgouUjHW8UDJrxTwh0xvTcN4h1w6eMfLilstEK+xXLnbcLyRFCgRaYxzYH3hpR+28+xUe73j203z1KXQefoI28vq3szyhPqTesR6YAofnVE/0eOWwZxNYC/PKINwXLi95Z2ZrHBZQNgO3R47wcCQ3Bs/9gTgyr+TGg5HbmX1j42mVWnkwsGQsoI0VmrB/yqPI8cC45WmmJ+qGsLBptM7xSFm4vua2oEp0VPCOQXiy55WHeOOdF9xck9+lLzzc0ToEaqN7hsSLd3h8wXEBh428qMTA0DChG7YicLHlnScQCJ1m+EiuFGGjLJ2XjM+8yZf+NrzzjSP01/xe6Ow+QXe4grvAz3x9z4OBY6dfEGbmF6TIGGie1qlKEPI9V1vujiCsjhLwgrsnDpSOdtYjLkKlweope6SwHaidCtXRPEFZ79lF7vbUyrjlXgiRqbIKIaGVueIaUXlxJO1Y7qmNq4lW6AETmscC+RmT4T1zoUZsxAekwIoLrAulsZ140UgJXViMYlxElpUKoaCe+Z7dFfsHv/3P/sav/8JhU/0iZUhDnm8tuCl663nOZSl958PoCAHUW3elZNdAZW5dzEXEs6DNNJoIqHSteFPaetxEEcuZhoVtHKxkCdZdVAJtjuIqbd9Izm21al+LcwWngqpvHUEwWrm73LjWS4UUR7p2UwlSTQJ4qb33JpLXFpWd62pziIPRF2siTmtxwVVaoYsWZ/tYmmpzMW4sPRjjIHf1+N6nHoaPfPLRO1/5oY9+5JUf/vM//Uf++E/+/PsL53Vo8NYvqcZ+7D/+loffdOF55cuf+/xbT98py7GV7q/cS69v/saffuvnPsfj38Cnv/XiO775mzq2vz123nv9k9/09ad32W3ee6f9F3/wx38cfs0nkCPf9xX+m9cj+O98J8/e4Ue/yqc/wu/73vTxB/7L7xwevOwvnXt481KcPmTX36GXH7gr2/XtH/3i3/0Lz493y9Heqrr78Cu/84/+oYcvffKdZ5vbfbh7/mJ/fPd43zMFsZKrc0c36vM5Ly9anFIxOSwtr8XFy0Vabr6bn23aN8vNB3El965WxeUmJSOmjVDp63zcJUerB+tl9YMmFd9drRKCunU9OufN6f3cRu92ntbzITefUusYzjBwS7ac593G5drmnDdp07trBJR1XjeDb+sBkYp6GaIgfcXpWtU7XEh3x9Ib6piXrl3F9WySwfuYi6bg9muNfljXo3OhthJD8sg0xXGj6nnp+qOfsc/wxTtywEMS7u/QjuuUBe8wxZTSUaFXpoG80CoaWQ3v0ULvpIllRTxSqZXaCErJVNhM1BWN1EwDjawHEpRCjJSO+XMZOR+YInmBhjm6ByNANaYNxxnvsEozamHwHA40YTtiDfWUgjlC5HjL5Oid0tBENnzCK7kQA8sBlKTsK8FDxoTuiZ7ecI6SaQaQDzjwAR+pQgETxpH5jimSMwZuIHfUYZkYWCoUuhEmkhEj3dH3pPLdN2/9+pfs9/2eT7/0za/r+uU//Qf+5P/hz735/Jcu+o99B9OHefl1tiOHzuNXmGfuMxcjh0xItIx0cmGb6EKvLJUYqTOTJ68ET3PM4EZSBYjKcsB5VOhKaWhGhAybiFMMrBJH9itJsYZXzOhCO+EjT6uEX3xaPN3InahIpxjB0wpOcY45EwJlZUh4RT05g7IKCaQTlLpwPBJ359ILYKUVRo8qx5kQWSsnVDkk8kJUFsMc6nEFjDCwrkjACrljRoC2UJUxUSsaKBlxmGPJbIT5SOu4EVWcR5XqOb7mj4eEVXTm4oJ5xo1y5XeY9TI7N/rI3Jp18dZzqzKUvCclJ9o2lygXi97llUHpB5LDTdjK9ID98ZehX+C7P4TbcbkwGxdb1oZPSAClFbyjZLThrgnGVnn6S7/5Xa5e4/iEXWYSKHjHgyuWI62h4KEJquNmEGHuxTtpZio4z1oyueHTVLQ6l3WltsHvRvzzfq8+aO7mdTLuKew2MYzDfT6E2IYs1fkm4nrWwuTJUJWQxMS0qtNQdW2d3tE0lTR7M7kl2xQuYum37gBCMbyk6lbNTAPTlhczK7gKkWbQUcElshBGrBAS1vDpjNX7CgPdkzz9wPEF40vaQy/3YKwgQlMC6MK4k3Vn7Tn+SPOgSKeDBtZGGDDwDilo4thAKBmJlIFgyD1V8Nc04ABKlTPadw0Pw45bwwlupgeaIoar4MFBpLdzM8IM27BmYqYKtkHB9uBEL4bO3A64SO2oDbMsIRBm4kW4DYUDw0JTzOGMVs6w3I+UigvUTNhtWziuuSej4YidZg2Sih8vfMzLodK6ipQ2+GhNpjgImkvvKd3n97j/ODFxNTNcsN9z9YirkXLk9j3iiEBI5BVTaiEMjJ5BWI2XHrI25jtev+LuiMCjK/pK3PDhj/P217gJWEeMpgSjVC5Gjkc2NwRPMmrBT2zAjGPlaoN35BeMgWPl9ceo8u67DBFVfEcjy8rukjcuGDy3t+yfkw2DaUNsuIBdYQsmiFKMq4eUSu8cjujIKDgj3RAU12ge76GfcZcaCUZj23AdV7gcGCeS0Q1XGQMbB46+sI7kCZuokXFANhwTuw1rxwV216DsVx49hsxx5tVIz+Q7EiSPFyyzgXUhRmoljqiSdzQhBoaECHiygMMZL3vmleAZDPE8nmhAoxeGSp4pGYzezhtc3FALY8ILDwaKMF4SlOORLpjRG5cjtTCvhCt2iXzH3ZH7xCSMQoNaUeMq0QtUZCENv
2022-06-28 00:58:17 +00:00
"text/plain": [
2022-07-09 14:30:12 +00:00
"<PIL.Image.Image image mode=RGB size=1280x1280 at 0x7FB4B7AF7D90>"
2022-07-09 14:35:17 +00:00
]
2022-06-28 00:58:17 +00:00
},
2022-07-09 14:35:17 +00:00
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
2022-07-09 14:30:12 +00:00
"name": "stdout",
2022-07-09 14:35:17 +00:00
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-09 14:30:12 +00:00
"CPU times: user 1min 23s, sys: 466 ms, total: 1min 24s\n",
"Wall time: 1min 24s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-05 13:37:52 +00:00
"intermediate_outputs = True #@param {type:\"boolean\"}\n",
2022-07-09 14:30:12 +00:00
"grid_size = 5 #@param {type:\"integer\"}\n",
"supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
2022-07-05 21:22:20 +00:00
"top_k = 64 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\", \"128\", \"256\", \"512\", \"1024\"] {type:\"raw\"}\n",
2022-07-05 13:37:52 +00:00
"log2_mid_count = 3 if intermediate_outputs else 0\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_stream = model.generate_image_stream(\n",
2022-07-05 01:25:38 +00:00
" text=text,\n",
2022-07-05 13:37:52 +00:00
" seed=-1,\n",
2022-07-05 01:25:38 +00:00
" grid_size=grid_size,\n",
2022-07-05 13:37:52 +00:00
" log2_mid_count=log2_mid_count,\n",
2022-07-05 21:22:20 +00:00
" log2_k=int(log2(top_k)),\n",
2022-07-05 01:25:38 +00:00
" log2_supercondition_factor=log2(supercondition_factor)\n",
")\n",
"\n",
2022-07-05 02:52:00 +00:00
"is_first = True\n",
"for image in image_stream:\n",
2022-07-05 02:52:00 +00:00
" display_image = display if is_first else update_display\n",
" display_image(image, display_id=1)\n",
" is_first = False"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-07-09 14:35:17 +00:00
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-09 14:35:17 +00:00
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-09 14:35:17 +00:00
}