min-dalle-test/min_dalle.ipynb

210 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-03 22:40:27 +00:00
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-03 19:35:33 +00:00
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-03 19:35:33 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-03 22:40:27 +00:00
},
"id": "ix_xt4X1_6F4",
"outputId": "9c4219ee-bf34-4d9f-b234-ed3fa2e98cef"
2022-06-28 00:58:17 +00:00
},
2022-07-03 19:35:33 +00:00
"outputs": [
{
"name": "stdout",
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-03 19:35:33 +00:00
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: min-dalle in /usr/local/lib/python3.7/dist-packages (0.2.10)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
"Sun Jul 3 19:28:49 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 32C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"# Use %pip so the package is installed into the running kernel's environment,\n",
"# and pin the release shown in this cell's recorded output (0.2.10) so that\n",
"# later API changes cannot silently break the cells below.\n",
"%pip install min-dalle==0.2.10\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-03 19:35:33 +00:00
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-03 19:35:33 +00:00
"outputId": "5d3b2f09-f3c7-4759-dc7f-fcf7a6366e54"
},
"outputs": [
{
2022-07-03 19:35:33 +00:00
"name": "stdout",
2022-07-03 22:40:27 +00:00
"output_type": "stream",
"text": [
"initializing MinDalle\n",
2022-07-02 13:31:20 +00:00
"downloading tokenizer params\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
2022-07-02 13:31:20 +00:00
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
2022-07-02 13:31:20 +00:00
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
2022-07-02 13:31:20 +00:00
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"# Build the model once; the generation cell reuses this `model` object.\n",
"# In the recorded run, construction downloaded the tokenizer, encoder,\n",
"# decoder and detokenizer parameters before initializing each component.\n",
"model = MinDalle(\n",
"    is_mega=True,\n",
"    is_reusable=True,\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-04 13:29:54 +00:00
"Note: a 3x3 grid will work if you were allocated a P100 or T4 GPU"
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-03 19:35:33 +00:00
"execution_count": 8,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-03 19:35:33 +00:00
"height": 972
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-03 19:35:33 +00:00
"outputId": "228430d8-9b29-46fc-85c6-b20b04f1d88d"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-03 19:35:33 +00:00
"name": "stdout",
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-07-03 19:35:33 +00:00
"['Ġdali']\n",
"['Ġpainting']\n",
2022-07-01 01:23:15 +00:00
"['Ġof']\n",
2022-07-03 19:35:33 +00:00
"['Ġwal', 'le']\n",
"text tokens [0, 21853, 1545, 111, 563, 92, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
"data": {
2022-07-03 22:40:27 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nIy9ebwt2VkW/LzDWlW1h3Pu1Pf2PKSTbjJPJGTCSAiJBAQDCKhhVEQkihMInyh+n4h+oIioyCggkSlAGAICgSQkJIHMZE46Sc99u+/tO5x7zt67aq31vu/3x96dhIB+1u/+7u/sqjpV+5xdZ9Wznqno7//P3xzbvFy6UhIJY1yvHbSYpdKqaWrFNCipOmMap71Z715X1szSIEKtiUZhEcpkY5dyhR/Vxk4dWi++sYAqBxAMkIc3i3Fandzr2lQKMSOTUYBDo1kkAlkRJdPYrCOHz6W421FTzcxMHmG1UgQxT2U1dLIZa6XotbPmTOHhSdnbGDHVWmd5KBYswQmjw1gJNrWNsjQrhyWyyl7SsBptGlK01WEmvub4gqGTHZ0aHvbpXsXRmZOL5z77uWnEYjglMi55YXFwdHC0aVdKEOxSHe/vLhx4qdE2SSM269mNvGgXuwJv0OPQK7iwQrePzVnc/wDO3o2PXsT7PoZX/SEeK/jpn3n86WvuMqk084JGgAkX12Kz1bQ4OPC27s5fnp/7+P2vfdfw7rdcvu8KMciQQUXz2GTorz7z7C9+zlP/whMf9ZRnXLlo955vb/3YpXFx6vDwqEKyhNaVmUsphJj37MytlGi+N3RE2qZVP+/HzWZ/1lnbBNLUynzoVpePFl0nXKtTSIKHN4LkYHhZdwihWuClZfXsFshRqWMobNOlXNk3YyhjCAOVKxubz2cUQBAhAtSMSt2cWOZxU9alzYelGztnZ2vVs1CiVluphlmek1XQ5BCLxGRKQuxTqYvZsD48CgfQWHJVacYUAkIF13GaZbKo62mtkdE8D73M2Hg28OJNl8/hgQ1qQiZIYLOGBCgQBUwAIYDqSAoECAAARzAaoApqAMAZbiBCGKzBA8zwBicMGa2AEqLBCcEoIwZGrVBFMIIQBGJYQUrwijIh92gOTYABABQEuO3+bwYStAkgdAkRgMAagsGK8QiLjFrQGriDE4ih9MndVFAnuKMZNGGawIxOAYIHzEEMUWyOMM8oExBwBhKIoIADmmAVRPAGc4SDCXVCA5YzmAEMbwgGFHWNuWKcEAHqYAQWJIIHiFELuoxpAwiqQwAzOGPo4AYAYWgAC2xCYliFECrgACsYIAcCzZAEIqiGICAAAAEVAIgAMWL7hYMAGKYCIniBAZlhDk0IgiZEAyu8oRnGitQjDwBhM4EIJACBABW4QQhggNCAlGANEegE1kCKCJAggGhQQhgIMAYxwFACAixwQ2wvLcAcTLAKZ3QK3x7H0AIQ+AQxABCFAawA4AF3ZEFtCANnuIMJArigBZhgBmK4QxRmkBmmk9d91zff/8HDDpiogxhqxVgw71AKFCgGEjAhAAdmA8qEZggFFAHwBGVoj2mkJGEGczSHCcoGBOx3aAZ84ipllCMcU1zZIGcEYyJAMANqQ8q4coC+AwMbICdMKyjQgF4RBjCqQRVnNzjVoW7ACjdooBCIgAY2UMAnrDZIc0RDpzhc4dgAc3DCoOiJa7nuWHvm409O73nHq//RNwJ4+dcML37GC//99/z26x+oAL7ln3/PT3zvfzqazuJPLfO/+IJnPuczv+AZn/3ENJ+prR+8+NBikc6df8fU7vxXf+uXLwIAvvVf/cAXv+QrrumvLM7srw9Xb37Dr/yNr/nWl335Z7ziFz+0B1zB/9HyrL/+1N/4H9/GD9zxJZ/xL/7g8JPrE1ABBm6Z42OrT67/+9/wpOsf/9RTV996fDHrunxwuR1dHPeOX3316auH0zhx9fA//vMvvPInX/Gvf+TlD70rf/s/+8lVu7c9Mr58yQ///K+/82j5WV9q87opLY2rTmSqUJajVpfzObXxYDXuL/bRqlfsz1VYpjqCGQgh2ZQawj46e/S9w001t4A3dlYVHK43Q9dH20D4qIRwxx5JWqgQlNwma5rz4V
hasx4cdfLA3v7M3Ji1NBcSZj7crI4f664croe+bxV9ShEIik11yWq1Tmasas3nmZKyOxGl4iFC5r4evet6j3GRde6DPmwnbXUwzzyxbOrIuTuRu94mJr7ikUUXoSVoXaZ51y/X+WGe1nmaU3fa9+agc7IyIDUQUzfWI2ljjiVmJ8flFTvyPIkruwtbC3Nwgc+GToq5Jof3rkvpRmtXvCZK2Z2Uc6sPS/XFcIZPlINLGxqlQ2LicAtXCJSnOiVx9jEnDYKC5jKjiMolzMllwqbrNDz20iKLTbgC5cq+QXVmAGng2UZF8yhGEoapg+bZLKVu3cWZYTg165c4dmzvidctcfPp/sb9qzqI48HN2M7e+6G77vzwQ+fuXx0dToeXNqtWGsxAhDMnsXQsFN3deNLjcWoOVOA8omI/ox6BBScX8NO4vMLU37jBpWe/5PNPnLifu6REUtcLgjOsZaSu6V7j/e6qYw47OD+dP3nrzY9Kd33W4+95wB784Pl7r6xKPVg56liL3f3BV9xz9jdf8xkveNbz/uqX5NXmcYvurnau067sCTX3kbAJGdIiq9WptqKg2WIYeHZUDk/NF+upHE/L1LwZNyqd6nqa8nymQeSbjvLGWhfacSrul+soLfZaN8l0hVcdd9fbnlo8hKNVQEojssHW57RNSU8sjqcLZVVaN0viopWUC9gm8IE7s7Wj0UNz1y+pU6QjLyNikfqh4lJbb2LV5X5fun3fu+QXNuHsLoDCSnIRTHU0TrXWWRoopLhpYBG+gV1oU5dUqhz65FkGnl2DvTW3ozaue22xxjiDBLQhKdYTFjMooRYYEMAWpwkDQG3oMsIAhxGWA4xQJiSCGdgBghJCMMzhDdZABm87+GIEFRDACco70BOEpKgGBDhBFGFIGSkjC0gwTru7GsfuflkCswUEmBLQ4A6kLRxAUtSKod9hspygHSDwQClQhQM5AYEuYWrYO4ZWwAJyhCMEXsEMEZSCnAFFFrQCTugGeKBMyAoPiIAIiVAcsyVagXTgLRjK4EAl5O2NtgcROkUEOIEV5igVfY/S0GUQ0HWogf1jGDdIADnCwAkcaIRlj1bRAsJgBRxZkDIiEA3NIIzsYEIEIiH3sApvQIAIxABBFe5oDUxwIApmHcxQE+Z591kTgQITgQKLDA+sJ7QJXQ8WjBNEMQwwAhNaQ05oDUlAhOLYm6EZWCAAAkhAwAk5oRVQh0RwAxw5wwFS1IIuwRwcEAUHCrBYYLOBGqIhDJR3GC4nqKAJULGdVSqBBeMnDqJIDHP0s92lO07oMjQQFcQYHXkAEzSBgL2D85eAGQcZhLAhLJT38wI4GotXRy/zrqserXlS3Sc/Nwa4cZ6dGtSaH7WooC6M+1hEO9AYjbC3vCbR2UMkL3ADB0ubDGDAJ+znjm3aS0CgRb/fjWPFuiIlJbMTiaj4keDE3vG97tJ5w1SQATUYYIYsMMNt+wjHpu2AVBZdJgWNqxHBSAoDNKMRJka/wLGTODrEiR40YdWwFj++f2HZP3hu84JT8moAwHf9rb/3pv/55tc/UK+/jV79lvv/67/9wac/9ZZv/Affe6y/8s43vvOnfuL3P3r5LmD1+te+7vWvfV3/vcNXv+wbvuwbvvzWGx9z4e47HnXyiZfOn/7n3zf7h9/6MwB+8ad/4fOe+/STNz4pZ5y+9tilxz7mlgX+r5f/8Ct+8S+WT4E4L3nGid9628X/FQD6o5991/1fc/HNb3r3p6IfABUA4MA/+e5v+qZ/+F8/sf7db//YbZ/1OeVws5Esnudp76D6seNXX33qOpXN/rh45mNueHOXuNx+6vp6ZnZuXF4tdvGjDxYF/8q/fcWjXvrsRz3qwnsesmEpoeQI8ToMvXJ/fOCHLuLkctib83rKs31KrIebTUrSqQjF6C0F1h7LU/2e0NGVIyVhCWk+zGRCHByOyyGdVJwTGafNIs9PLLNbHBS04A4hyRdBG2rDLA3D3r7GxcPNAA93MLOECG
nizdjmQ+awvcUgSpGQE185mppjnnWW5DBFm3g+zOZzIdSLl8rQpyEJuXvQ5UPKokP2TrtJMOFQK6KNG80M8
2022-06-28 00:58:17 +00:00
"text/plain": [
2022-07-03 19:35:33 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7F086EC34C10>"
2022-07-03 22:40:27 +00:00
]
2022-06-28 00:58:17 +00:00
},
2022-07-03 22:40:27 +00:00
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 19:35:33 +00:00
"name": "stdout",
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-03 19:35:33 +00:00
"CPU times: user 35 s, sys: 0 ns, total: 35 s\n",
"Wall time: 34.7 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"grid_size = 2 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
2022-06-28 15:05:59 +00:00
"\n",
"# NOTE(review): seed = -1 presumably requests a random seed - confirm against min-dalle.\n",
"# Generate first, then render; the recorded run displayed a PIL image here.\n",
"image = model.generate_image(text, seed, grid_size)\n",
"display(image)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-07-03 22:40:27 +00:00
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-03 22:40:27 +00:00
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-03 22:40:27 +00:00
}