min-dalle-test/min_dalle.ipynb

216 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 01:25:38 +00:00
"execution_count": 3,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"id": "ix_xt4X1_6F4",
2022-07-05 01:25:38 +00:00
"outputId": "4451164e-86cb-43d8-d0a2-a6ebef7d3938",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
}
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 01:25:38 +00:00
"Requirement already satisfied: min-dalle in /usr/local/lib/python3.7/dist-packages (0.2.25)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
2022-07-05 01:25:38 +00:00
"Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
2022-07-05 01:25:38 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
2022-07-04 22:40:32 +00:00
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
2022-07-05 00:56:45 +00:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
2022-07-05 01:25:38 +00:00
"Tue Jul 5 01:08:31 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 01:25:38 +00:00
"| N/A 36C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"! pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-05 00:56:45 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 01:25:38 +00:00
"outputId": "f79d4490-cf78-4642-e1f1-5b74c26fb2be"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
"Note: if you run out of GPU memory, reduce `grid_size`. A 4x4 grid has been tested to work on T4 and P100 GPUs with `intermediate_image_count = 1`."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 01:25:38 +00:00
"execution_count": 3,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 01:25:38 +00:00
"outputId": "cf14c540-6fad-4a7f-fc93-2a9a5553e02d"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-05 01:25:38 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7F6BCA36FB10>"
2022-07-04 14:21:12 +00:00
],
2022-07-05 01:25:38 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nGz9e7Ct35oeBD3P877jm3OtvX+Xc06f7j7dffqSpE3SAQMkARIuDUQoApIEU4D/QClo1FKkiBZVYJVWWVoWloolWFKFxARFJUZBRUVCsIKpSjoJ0Gk75gLpdDrdndN9+vS5/PZea85vjPd9/GPM3WiV+7+991pzfWvO7xvjHc+V+NIPo54w3mg0dBk4wQy6zIhngpdBK0LUkKtAyqscJFPqlj1BR0T3kgZsMygxQFYt1v1WHAQzI0eIbkopi7YIEzASmJlsiwIVwGhCAZPVTRC1QDYVwQit1SAhmm1EddOoBfTKQwiKAUtKgxbI7naXCTeIptschN0N0l2TGue8i3Cp10JSGaRXwd1unMuKOi55e5nn/UQRLVgcJtVt2j5PKODCMvPq84UBBrsbFtaCASV4O56u592Yd5SQT1gLV4LCAqJw3hFX0EChDqwXHEAe6AVd0IQNBeLEGFjGvGEJSNAQEInZOIhu6EBPmPCB+YI8kQMN6EARTQShhZGownmiB+IAiEOIgdkIoCc04AIMHjhfkcYIsELpRjfhgO6X61hn1/1EJ5gox5VlYkGxei0ogQYgHL1OJjWCBFq9CpCZ5JnJXkDXujc4YI5nGa5yqmqtRsAm7Q50gVBGd8OJNgCYiIAWOHCeyAR4XJ7P87PrU9y+lv+t3//rf/6bjZfbe89U9b0QOK7HvbosQCpfj4MR717ePWVkaKbev68D+aTIC5YUlQPrcmjR33qZQ3kNpAxdGnY5NArdgBfhM2JlHE22QxCkYs/7/c2RwpqB19u6YISpREWEc6iJZuhdrXPWc8QV5YhJyiSju6Vg6H5793SV7UlURXQE5fR0Hw6uM1It3KdRfaGPUMVYZS9nDgurygtHcGnel9hOHQWCrK6er0/JoE80CqkkuNgn1EavU0ZejlfwpTzA7j4uFwivq0o8q5qx3FAUAbclKlaXxWdldZM66z6IhIN8O47GbLKAEFkAfZXe5rpQ6du1ehgxBMglk3QFQPrs7NXiQLLYt2+8fzPGuPZZfZ/TpWMcYxzds6uifIzs6GXUrRORB8ro5UCPNNO3c625DnpEkwOi27YgEO5irTNHwb7XSl5IMbLZa97DuEQzb7fleonoiMtyDHXSFQTTL3034Nf1FGBiNWkRdNskyPP87HqNtc6X8wxf6TSzo+dqFQ7NwsuUbu8yer35aNXJ1/e6TXNw1ap1j8tHk0drfP7NR//cH3j7zV9cR+c8RBi1GhbpQATXbNkhglq1xgi0GbIpy6aiGzAYhLsBl0EiQcIGGdFeQlQ3qCoIJTVJEjBJGQBYXSFxvwogsBsRNmAraBuG2wAcINGgINkNsw2De2WXiiRFN0naMFDdIdI2YUtkG6JNwoTtbgq1moJMuJug1C4g2gboJryoBiAJBqk2QRgGQLDdgAB4v1WgwZAM2G6w3fT+ZlMC7EYbBKthr5HuKkhCAKwGCNMAabSLwS6QEAywzRCBhlVukFWmK8LuZggQIZsmGiZF224Gu0k6SBltKNRdYnbv1wGxItqPnxEE2zRgev/87mLQhQgJbcjNCBptq7tN9oJYoSKIx57/eN/KLYV7/wh0ISS5SXSLAbthtW3vz3eGLIIR8H5faKANCmytXhHR7QgZgPXhf9tAlY/I8hlK9Mf86K/+zx25OJedDZBmOPfEkJYPtIAQGyZBswhQMiQOuElD04CddtFhErJNgRYzwJSlsFw0DLbtMgjAMAWChLTvYxOkZcjifjwgHyEHCxEKrwXQahImbRswpAgJ1ef+Cxo27Oh9P5KrK1hzWRxhNpoUiWYF5dAiUpHUXOfqO9ptucu2EZDLpgwyhIyEw2Xvx7+6XQdjYa5uObjgKuvuXquCnsEAomUYUp9rjbB0UTxz4ezVjbRbU2TL9+KIYyDaZ/N9AYUDyzCB5LDtkbqv8yKUk3
G9MG/nWeWAmwWL6QW49aSDaOerYeKo+wykHR1tU1J1ZbQR0qFmuckQjX3nJhdMxEUJVONGNDDWWm4DCbUBioAzkDHc6XJ1PWZAVwAFFylEIroXNOGGR/UMDoCm24jgWZ1RqSEdPr28BLnbXGiYrkYgDdrLKBK27HbTJNjVZGjOOpQmRozIuB7X8nk85w//13/t+e7dJ7neH8f9fvs4xjEEn12Y1kiFVcT9XteIT3O8YH2zzyc+fxveRuEb8b4wruDI0lrf4FyXp88dT3p/Z/Dmc3C4HOIJn15DV9U8RrfVIIzAuFcXfFBvyFPz1TN9+TSeo/lN3qw8mhGNnjfiW9FP8fTFPu643XGTD7QpGAbjviro58SL5yRzjY/H83S9rxXMa5Nag/ObuL/38Wk+Hec9o9+XCQ4C1M2xMN8iUfMW51njqnHIr8YJzeor5ufAk+ub6CvG1aPpV5+vJpms9al0l38u8L7Hd8aVVYxaWGe7x/FZGXEIpHCHs1uMy4jX7gYOBKAgvrnOS2aANJ+TE+XBWTVGgqHZxzzfcj3N+xX3T3kZ3UvrTl10udd5icB8TfLmscCrx7u+3cyPcHnb65Xvbw1IT3EkoguQ15of8Wn1y4x1q3XlR9fKu14LSgbVydsLz1f7c9e347bu/dJhmtUdwVpuFxzZ80jcVikjcITzVvPsleaTLq3XVS/N+mR8x5h6x3fLoYKiAvOl7zc5enyKZ/N26p0r8TgSuTtvq+XzoKe75Ce9PWq8tt93yXzisO63/sa7Wm/zSx/x+Mbrz5IhPqGEhVqztabyVv2LPN6++l/4w99b93dPQjPPtUZQQ12FtdqQqD0XtVMaGSfa3WpdcrB1YlKkYVQAyLi3L5FZOF1mP6YbaXW3PSK62y4JomiSWOU2BOTQ7IZbUEYGdNZUEOXHS0Xcqp/GwdmL1bBsgibLbjsV3Q22iD2VAGjv8QKZmrC7wgoGGoUiAdt7wAidjQSGdbqAFgiiwYbbkNTdkkmIsiGo7AYAhDRr0SCUmTJnF0V2m4DNHPfuobxQr7Xohg2C5N5gbMAdAcMC3RbUjQb4+LISAUWBR6hmtb0PgQD3gLhnKaIz2TAAmVKs6YYp6sPrKPJsj4gw5lqk905soO3HpcEKgHtTJo0q9B5hpOVKChIihjTvyzRsAbb9+L08IrubLO4nvyxGFcoFMqTlChLSgq4Rda6G7QqzaUKr20ZKEWos7nm0ENKcjRDsyJzVAShj70Fn7U+Q7ZZdVlCRCMSZxBH56Xd8j/yNcNgC2nAGpGF0ho0hBCGiSe1JGYDZZSqSaBrmsqItoQg19yyucFAQvIecYGy4xuwG2nq8HGW6u4Iqas+I2NMlsadfwLBXWXBKeb0W0Kju2kf9hhvBfT7gpYimKECAwf06zcQArLRA2rIBAgiG9xcaQABxDCXTbHfs12nIqG4EA/uKgD2Vsw0SBvYdXDW6BAXkdvV9wxIpg2yg2VUW+ASJKIAIXX3YQFCwXWuaPqSQ0Fjr0rwKTuR4bA0CauMcFyjCBVRB4PP1qS2iQHdB2iczJel296VJMd88bfiBYHUjHCQILzsYwn6wh4S9SkgsmGAq3N2+tLsbTxdBNMNYbQgSH6uPoD3EtoN+nIJsNxwIKlw2JwC0JDRoco/sMC9kwMsIUBfYptUw0W4bXXZQDbhgFGBDBAw29qKCbjw/pWgzUX3EWF3H8dH718G3g5+9rJblay7S92WlFhvqImebGNe37PN+W9XB69CBXuteh49AcNEB17rwiovgo+4NnmuNZHlZbkajRsSRJwr3czFHo0i15hg6APQ5QcODmfT02cIlsN/yZp9VYH7uSPikWWdpoD29byo23U9v5F7dDCIi8tDZd8hvaPQpZ9inTccnoWeuCr6UWxZw23gsfIkAysFzxlAU6hU4zWY/XxQL56wCasSLsVyNNjFSt7Ui+BnqDiAv19C5prtPzuV1sqruS/EkD/TZq9xvlWbcwAlfIgLR5VnrAn
+qXOd8E5nV30JR+phydy1eVJ9pYZxe756HVIPJ6bZi5muj393vR2oNv6ucrRNjxrwtfNTrrPUe716k53y6e
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 01:25:38 +00:00
"CPU times: user 34.4 s, sys: 558 ms, total: 34.9 s\n",
"Wall time: 35.4 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
"intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n",
2022-07-05 01:25:38 +00:00
"supercondition_factor = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_stream = model.generate_image_stream(\n",
2022-07-05 01:25:38 +00:00
" text=text,\n",
" seed=seed,\n",
" grid_size=grid_size,\n",
" log2_mid_count=log2(intermediate_image_count),\n",
" log2_supercondition_factor=log2(supercondition_factor)\n",
")\n",
"\n",
2022-07-05 00:56:45 +00:00
"image_shape = [256 * grid_size] * 2 + [3]\n",
2022-07-05 00:08:31 +00:00
"zero_image = numpy.zeros(image_shape, dtype=numpy.uint8)\n",
"display(Image.fromarray(zero_image), display_id=1)\n",
"\n",
"for image in image_stream:\n",
" update_display(image, display_id=1)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}