2022-06-28 00:58:17 +00:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
2022-06-30 15:25:24 +00:00
|
|
|
"colab_type": "text",
|
|
|
|
"id": "view-in-github"
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
|
|
|
"id": "3WL-G_f2_ld8"
|
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"# min(DALL·E)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {
|
|
|
|
"id": "Zl_ZFisFApeh"
|
|
|
|
},
|
|
|
|
"source": [
|
2022-06-29 13:43:42 +00:00
|
|
|
"### Download models and install dependencies"
|
2022-06-28 00:58:17 +00:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2022-06-30 15:25:24 +00:00
|
|
|
"execution_count": null,
|
2022-06-28 00:58:17 +00:00
|
|
|
"metadata": {
|
2022-06-30 14:02:08 +00:00
|
|
|
"cellView": "code",
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/"
|
|
|
|
},
|
2022-06-30 15:25:24 +00:00
|
|
|
"id": "ix_xt4X1_6F4",
|
2022-06-30 14:02:08 +00:00
|
|
|
"outputId": "6a676bd4-782c-4c47-8b09-3895b4314c6d"
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
2022-06-30 15:25:24 +00:00
|
|
|
"outputs": [],
|
2022-06-28 00:58:17 +00:00
|
|
|
"source": [
|
2022-06-29 20:32:51 +00:00
|
|
|
"%%shell\n",
|
|
|
|
"\n",
|
|
|
|
"# Clone to an absolute path (matching the absolute paths used below) and skip the\n",
|
|
|
|
"# clone when the checkout already exists, so re-running this cell after the notebook\n",
|
|
|
|
"# has changed directory into the repo does not fail or create a nested checkout.\n",
|
|
|
|
"[ -d /content/min-dalle/.git ] || git clone https://github.com/kuprel/min-dalle /content/min-dalle\n",
|
|
|
|
"mkdir -p /content/min-dalle/pretrained/vqgan/\n",
|
|
|
|
"# --fail: exit non-zero on an HTTP error instead of saving the error page as the model file.\n",
|
|
|
|
"curl --fail https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output /content/min-dalle/pretrained/vqgan/flax_model.msgpack\n",
|
|
|
|
"pip install torch flax==0.4.2 wandb\n",
|
|
|
|
"wandb login --anonymously\n",
|
|
|
|
"# Fetch both DALL-E BART checkpoints (mini and mega) as wandb artifacts.\n",
|
|
|
|
"wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0\n",
|
|
|
|
"wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14\n"
|
2022-06-28 00:58:17 +00:00
|
|
|
]
|
|
|
|
},
|
2022-06-29 13:43:42 +00:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2022-06-30 15:25:24 +00:00
|
|
|
"metadata": {
|
|
|
|
"id": "kViq2dMbGDKt"
|
|
|
|
},
|
2022-06-29 13:43:42 +00:00
|
|
|
"source": [
|
2022-06-29 14:37:12 +00:00
|
|
|
"### Load Model\n",
|
2022-06-30 14:02:08 +00:00
|
|
|
"Note: a high-RAM runtime is required to run the mega model."
|
2022-06-30 15:25:24 +00:00
|
|
|
]
|
2022-06-29 13:43:42 +00:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2022-06-30 15:25:24 +00:00
|
|
|
"execution_count": 2,
|
2022-06-29 13:43:42 +00:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/"
|
2022-06-30 15:25:24 +00:00
|
|
|
},
|
|
|
|
"id": "8W-L2ICFGFup",
|
|
|
|
"outputId": "e0a20682-6359-44be-b9d3-f4712e14900e"
|
2022-06-29 13:43:42 +00:00
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
2022-06-30 15:25:24 +00:00
|
|
|
"output_type": "stream",
|
2022-06-29 13:43:42 +00:00
|
|
|
"text": [
|
2022-06-29 14:37:12 +00:00
|
|
|
"reading files from pretrained/dalle_bart_mega\n",
|
2022-06-29 13:43:42 +00:00
|
|
|
"initializing MinDalleTorch\n",
|
2022-06-30 13:39:14 +00:00
|
|
|
"initializing DalleBartEncoderTorch\n",
|
|
|
|
"initializing DalleBartDecoderTorch\n",
|
|
|
|
"initializing VQGanDetokenizer\n"
|
2022-06-29 13:43:42 +00:00
|
|
|
]
|
|
|
|
}
|
2022-06-30 15:25:24 +00:00
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"import os\n",
|
|
|
|
"os.chdir('/content/min-dalle')\n",
|
|
|
|
"from min_dalle.min_dalle_torch import MinDalleTorch\n",
|
|
|
|
"from min_dalle.min_dalle_flax import MinDalleFlax\n",
|
|
|
|
"\n",
|
|
|
|
"mega = True #@param {type:\"boolean\"}\n",
|
|
|
|
"torch = True #@param {type:\"boolean\"}\n",
|
|
|
|
"model_class = MinDalleTorch if torch else MinDalleFlax\n",
|
|
|
|
"is_reusable = True\n",
|
|
|
|
"model = model_class(mega, is_reusable)\n"
|
2022-06-29 13:43:42 +00:00
|
|
|
]
|
|
|
|
},
|
2022-06-28 00:58:17 +00:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
2022-06-28 15:01:31 +00:00
|
|
|
"metadata": {
|
2022-06-28 15:05:59 +00:00
|
|
|
"id": "c52TV1GbBNgS"
|
|
|
|
},
|
|
|
|
"source": [
|
|
|
|
"### Generate an Image"
|
|
|
|
]
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2022-06-30 14:02:08 +00:00
|
|
|
"execution_count": 4,
|
2022-06-28 00:58:17 +00:00
|
|
|
"metadata": {
|
|
|
|
"colab": {
|
|
|
|
"base_uri": "https://localhost:8080/",
|
2022-06-29 17:55:23 +00:00
|
|
|
"height": 528
|
2022-06-28 15:05:59 +00:00
|
|
|
},
|
|
|
|
"id": "nQ0UG05dA4p2",
|
2022-06-30 14:02:08 +00:00
|
|
|
"outputId": "c5174520-66cf-43db-a51a-5ab260b21f99"
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2022-06-29 02:12:36 +00:00
|
|
|
"name": "stdout",
|
2022-06-30 15:25:24 +00:00
|
|
|
"output_type": "stream",
|
2022-06-28 00:58:17 +00:00
|
|
|
"text": [
|
|
|
|
"tokenizing text\n",
|
2022-06-29 14:37:12 +00:00
|
|
|
"['Ġa']\n",
|
|
|
|
"['Ġcomfy']\n",
|
|
|
|
"['Ġchair']\n",
|
|
|
|
"['Ġthat']\n",
|
|
|
|
"['Ġlooks']\n",
|
|
|
|
"['Ġlike']\n",
|
|
|
|
"['Ġan']\n",
|
|
|
|
"['Ġavocado']\n",
|
|
|
|
"text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]\n",
|
2022-06-28 00:58:17 +00:00
|
|
|
"encoding text tokens\n",
|
|
|
|
"sampling image tokens\n",
|
|
|
|
"detokenizing image\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2022-06-30 15:25:24 +00:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAEAAElEQVR4nFz9WbctSXIeiH1m5h4Rezj7nHPPnXOsrEKhMBEcxGYvtVrNltbSi/6H3vQHtPQkkdJLc0lsruYS2ZS6m6I4gCSaxECgCTZJAAQIkCDGQhWAysrMyunOZ9pjRLib6cHcY5/kBarq3jPE9nC34bPPBqff+Le/NiSSPnMTLGfADMrEChCTGQgGEGAEAmj6PzODwWBM5F8BAP8f899TYoL/kH+DyMwIsOmnCTCoKQAmMqtPIgLM1J9jxPWL5r8HA/zr/nAYiACQ4fh8AhHDzJ8DJioL9LfxxxgZlGhaNvzv5TlWvkNEZqZq5Q2pvuu0HiMzZS6PMdQdIYPVzfM3sPJivh6bdg4ww/TRdYN9PfWB/nHsX/I/RDAw6g9NL2LTeZjVZUw7VM7DCGTlWTCzaf0KMJgI6i/jm0kAoKaqRgZiQt0AIjJTNYMSyJjLEtTATAAZdNoHsvJ6qgYYE/tzCOYS4u8F03LugFr5kSJK5gdS9k3NYMbEVg5VCewfpuW9TA0EKAEJLJxVGUZdFwYlyonaOGYQiZkRBy17RagqQPBtrhLgSyQQkYF0EgT4j5gRgYTqcdZn3Pm5SYwBkDAol1/2HUI9JIDIyOqr0/E3CVTln7gqlUFdUagoSZEdIq3q8ZVFFHGHwYgMIKtqZmXZNgm1lfdF/di7z4GR6LQewIos1n+jfg5gJERcN820LMJwfHMoDAQiaN07/cpz/BeIuPyKlg+kaZOKEZqsDeO4Hv7KERT7QTDyg3XFICvKay46ZHVhDIBssmxU9AcAuEpHWT+h/J4dRQfm/zQGE7umAFDUAyWAyGCTBhIBxMVAGIgsux67QSnnW9fg4ueWl0Bk9eiJlI2QkY2ph861C0GDYdQEghb5LTtkRASjujBfTJUylJ91W0RVXK38xcWkiFDduiIS9Rjr27iB4fo51aT7htGdr9Bk3O0oCXdE4mjzXIlsEmwqkky+Vhx/0PfK97QqlN1Z5OTAzIWEy0l+RQsmo+YmrDojV8Lq1Kw+y4XNjYaZEbgIBwFGZR+tCLIZFfNfDqTujIvFpKyGO7bV10dfcbVf2aI7f76ioC6WdFTIO/tZ1qvFQ7OvHWXH6mdW+TGYEXF5OqHYS1hFDgYjdkxQ94ruWNHio9VQzmVSI2OQEYr9MEz6UtHKHc98x8MVt0gZRi6fkWjUgUlNja0aRzIm4uqraFJsYtQ3KK6sfBL5x6r/chE1MiKDKZFNrv8o7WT1Xa1KFowAniSqOHMqP2lU94bqIRdVKX4cVbVc+aaPsuPHuSWiuvCq3A5tiI56YUd592O0O/IIK08rJpPr4ZSf5SJNLrVHKFTPqpz8pH1FOaqew0BW7Fd5YhEWIxBDJ5Tk63GvhgKB6nOO2+jPqUbLrTIR3/1yAbYgM1PwHSd7dIy+y2ZQcAUlxbpPH+GC7jZOAaOCjF1otRxiPVCrkgP2r4OIwCD2h1T54fqidbllJW7TXfb8L1zXXTCrGZmRotoCg4FMoWA1M0CTWQwcgGwKYmImPyHSoxhx0V2URTIctx3NvRIR+K6u4/iHqllyEWCqto4mb1UFSB1TFsfKxFasERcRZzIzIjIYF3NQsE5RcTNmMiPi4pW/YgVdhIgd9ROZ2mQ2jvpjWtTJFERERKbFx/qDqEjvV963wrMqy8VsGJeYxzHEFEUBUCgREWNCssXm2R09NyoW19+QeXKHBvCEksyMid03WfEMVU2qXzUzUtcTc/V3lMLEqi5tADkyYnNQDWNmdXElm3ZGi9ViIiKwA3Zmx+7FLjFXy4eCzl1sfD+pvi9gyK6SbNXHVjtO1eUVv6FmHnX43rgDhKGcqJowab6zg+5q/ReJmJgUUpXXTAKBiQIAYgOZGti1U4pxdM9aAYRvo0Mj9iOios5WfZRWUwAYwGyYxJ+I/MSIUewqHW1JEb4CHoiLZFfD6D
IzwdXpGK2iKnfbNbItdpeI3RiUFRXrSQplu+NvuaKWYj+OgOTOc+gYJJjb6fq+Lo0GlG+7B5yUvdigY/xAd5SkWEWlukqzwj1UE1LRepE8/30BARCqoa0CpK4wNPkbojsohuCH7zDQpodWZwJ4VM2uN2YKoiPkNd9LAoG5wMAKBYmJQMRcV+P7WbFt3ZIC1+4IcYFGUyzh+wlgMnq+0wo2D0nKBlRcWTTFGGwGVSOegJ4W3Eh1QwAwq2YGaREkTaMGMiUjY0ixb1xkjAoqUkz+H+7IqKgXps+om0RUbeTksAoRpHZELnedldb3ObreIqaMI35GNQQeD6P4lumJNr3T8aOtxrtUl1lWVG24H4O7tioMdyTPnzGtBzyhQVQ4U55DVMMgoqo5Vk6zfMT04gUfF4vgXyGadpEKaKyvUISjoK1pPVafVCWq8nA2CSGzTdGXnyKq8z1KRLVCBgDCPBlgl+Yac0yu299hkq0iBXb3URP/Us2VvzwfhYIAqHoUgfrygJrzidWMOQYmMNigOvnqojl3ocbE3lVI7V4FFaoA7C7diGFqRGRkBpPoSMO9sqmZmWVzEHc8/PqBdAzP70KzEgvxJL8e5PhiajDL5T3vxKDuE4uNrUJYPecksVOIUAS1ilAxyFaD5bJOmlZRw97qYsjJC0c2dxZRjLo5BcB3dRpfiWGovoW/PE8OyqCOqKA1Zir4pq63LPloeYl836YwuiKCSnQUgpOLgrkgWl3PtDsEQOEECO6EokXgp/i0yE6hOWl6tyJkZU/IAU/12lRXfOcV4AiYYGSKuur6TkZQqAtXAXyODSbjgymeAUDM/m+r2mbVDvn762RnmKa4qB5u+TsRETPRxFfejQnrK3gEo0bEbmM80FClQExQJV+McxK+f1Y8LXH5NJrO3xyO+7OsYo477tbAQjCYFp9IVEB8gSdFyDwsAzOZGhkK0WrV6bp1N6jZZMeZYGrEVLhmjwmqby12t+w7zDwaq4i64Gx3I2QFWPtztKzTtb3oA8HM37dICcNXQ+U0SyqCuboMR0TVHDprMpn443OKb3WOoroOXy2qDvhvUvU5mIiJYnCZyKBkUK0MWvVL5HkJsBVhrO7IxYanfxTXW05Q61e4UmM6Oah6wmWDDFqtwyTfU7yihTaujmLKgwAGZJgLHBFVOmxyRO7FXFU1mwHVcDFgd88XJZ4De0hX3xAVObicUPUJVX+rzwEBFty8a91uqxpNDAI7miIiEpqMkxaoaBMBU+FmQaFTkGsV+Lkvd6RYHby6Rk4QsdoIWGUuy7ZM6bEJOzHD49UKYArDSUQy0YCVhJkOp2KQ+u7mjr5sP9XX94CCj8ZNoRPCcK/v0GIyM1xNudHdsM39ANXXOq5nWkaRvClir/C6vCqRB1zw2NHdhksJ1x2xSlhXxAIcAXsJ3q2cZlEMKdbKDYcBBLESpLve01FixffVk1b+KZOakcGI2VQNHvVQDcP5K87KbZxVU8NuZ8v5CZGWfF4RJjIuuS+PF4tYGohIikoVwaj8YJV9mmCsVXx2lMAjDPP8iAEI5i6cQjmOEpg4qjHSKpf6FfTq9t/smJ/iI3ir8lYN1ZGdPfopfx0/xmo7q3hXkEKT85380B0sSoRip3EXEVazBIfHVF6oWjtisyov5Tkwq0b4zht4RE4GdRxan1P2CFwPzOqTzQFQfaPiOKv19fW4GBZ0Uj7+TqbAX0MnXmFyVu4oFEf9QHmMHaOoCdwU44eCJty4Hy24Fg7JLQ0IILsT2JYgBgoImSoM4LIAK0pJmo08IK4m1TejgJPsyuNGFwBpobg8uhZfskIJUHOPUh7ii5qAbv1Tvqha0PDk9KpVIEzCeJSQaduK/bSv7AgICMWJVcvnnEaBhagJsfozJcouBvSOHGNKEx/JAcA8PrsDueHGhDydX11epUemD6vKNinEXatrU7Klglkqr+YGqpIr5flFOotddPW9g2xt0iYiTPQmat6RCjF/l6WpL1SOySY7Oj
0HdqQSyg9MamlcFNntPVV3UGMmAtw1gcCe7ylIsrgXm/AvCCRMZoUMnPynO64CqGjaMFctEoIxSD3gsfrpJ
|
2022-06-28 00:58:17 +00:00
|
|
|
"text/plain": [
|
2022-06-30 14:02:08 +00:00
|
|
|
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7EFDFB09A290>"
|
2022-06-30 15:25:24 +00:00
|
|
|
]
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
2022-06-30 15:25:24 +00:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
2022-06-29 17:55:23 +00:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
2022-06-30 15:25:24 +00:00
|
|
|
"output_type": "stream",
|
2022-06-29 17:55:23 +00:00
|
|
|
"text": [
|
2022-06-30 14:02:08 +00:00
|
|
|
"CPU times: user 7.4 s, sys: 11.9 ms, total: 7.41 s\n",
|
|
|
|
"Wall time: 7.31 s\n"
|
2022-06-29 17:55:23 +00:00
|
|
|
]
|
2022-06-28 00:58:17 +00:00
|
|
|
}
|
2022-06-28 15:05:59 +00:00
|
|
|
],
|
|
|
|
"source": [
|
2022-06-29 17:55:23 +00:00
|
|
|
"%%time\n",
|
|
|
|
"\n",
|
2022-06-29 14:37:12 +00:00
|
|
|
"text = \"a comfy chair that looks like an avocado\" #@param {type:\"string\"}\n",
|
|
|
|
"seed = 10 #@param {type:\"integer\"}\n",
|
2022-06-28 15:05:59 +00:00
|
|
|
"\n",
|
2022-06-29 13:43:42 +00:00
|
|
|
"image = model.generate_image(text, seed)\n",
|
2022-06-28 15:05:59 +00:00
|
|
|
"display(image)"
|
2022-06-28 00:58:17 +00:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"accelerator": "GPU",
|
|
|
|
"colab": {
|
2022-06-30 15:25:24 +00:00
|
|
|
"authorship_tag": "ABX9TyM+gbKB3WRvBMMXfvaRxqOY",
|
2022-06-28 15:05:59 +00:00
|
|
|
"collapsed_sections": [
|
|
|
|
"Zl_ZFisFApeh"
|
|
|
|
],
|
2022-06-30 15:25:24 +00:00
|
|
|
"include_colab_link": true,
|
2022-06-30 14:02:08 +00:00
|
|
|
"machine_shape": "hm",
|
2022-06-30 15:25:24 +00:00
|
|
|
"name": "min-dalle",
|
|
|
|
"provenance": []
|
2022-06-28 00:58:17 +00:00
|
|
|
},
|
|
|
|
"gpuClass": "standard",
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"name": "python"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 0
|
2022-06-30 15:25:24 +00:00
|
|
|
}
|