min-dalle-test/min_dalle.ipynb

189 lines
140 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-06-29 02:12:36 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Download models and install dependencies"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-06-29 17:55:23 +00:00
"execution_count": null,
2022-06-28 00:58:17 +00:00
"metadata": {
2022-06-28 16:16:44 +00:00
"id": "ix_xt4X1_6F4",
2022-06-29 17:55:23 +00:00
"cellView": "code"
2022-06-28 00:58:17 +00:00
},
2022-06-29 17:55:23 +00:00
"outputs": [],
2022-06-28 00:58:17 +00:00
"source": [
2022-06-29 20:32:51 +00:00
"%%shell\n",
"\n",
"# Clone into an absolute path so the result does not depend on the current working directory\n",
"# (a later cell chdirs into the repo); skip the clone when the repository is already present.\n",
"if [ ! -d /content/min-dalle/.git ]; then git clone https://github.com/kuprel/min-dalle /content/min-dalle; fi\n",
"mkdir -p /content/min-dalle/pretrained/vqgan/\n",
"# --fail makes curl exit non-zero on HTTP errors instead of saving the error page as the checkpoint.\n",
"curl --fail -L --output /content/min-dalle/pretrained/vqgan/flax_model.msgpack https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack\n",
"pip install torch flax==0.4.2 wandb\n",
"wandb login --anonymously\n",
"# Download the mini and mega BART checkpoints as W&B artifacts.\n",
"wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0\n",
"wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14\n"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
"source": [
2022-06-29 14:37:12 +00:00
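"### Load Model\n",
"Set `mega` to load the Mega model (`dalle_bart_mega`) instead of the Mini model (`dalle_bart_mini`), and `torch` to use the PyTorch implementation (`MinDalleTorch`) instead of the Flax one (`MinDalleFlax`). Check `reusable` if you are using a high-RAM runtime; this allows the model to be reused for multiple text prompts."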
],
"metadata": {
"id": "kViq2dMbGDKt"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.chdir('/content/min-dalle')\n",
"from min_dalle.min_dalle_torch import MinDalleTorch\n",
"from min_dalle.min_dalle_flax import MinDalleFlax\n",
"\n",
2022-06-29 14:37:12 +00:00
"mega = True #@param {type:\"boolean\"}\n",
"torch = True #@param {type:\"boolean\"}\n",
"reusable = False #@param {type:\"boolean\"}\n",
"\n",
"# Note: `torch` here is a form flag choosing the backend, not the PyTorch module.\n",
"model_class = MinDalleTorch if torch else MinDalleFlax\n",
"# The second argument is the `reusable` flag described above.\n",
"model = model_class(mega, reusable)\n"
],
"metadata": {
"id": "8W-L2ICFGFup",
"outputId": "b2bba56a-2904-4b83-9c10-b1f7518a9737",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
2022-06-29 14:37:12 +00:00
"reading files from pretrained/dalle_bart_mega\n",
"initializing MinDalleTorch\n",
"initializing DalleBartEncoderTorch\n",
"initializing DalleBartDecoderTorch\n",
"initializing VQGanDetokenizer\n"
]
}
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate an Image"
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
"execution_count": 7,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-06-29 17:55:23 +00:00
"height": 528
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
"outputId": "2e7111bc-0de3-4ddc-bfa4-c7a8131ce1dd"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-06-28 16:16:44 +00:00
"output_type": "stream",
2022-06-29 02:12:36 +00:00
"name": "stdout",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-06-29 14:37:12 +00:00
"['Ġa']\n",
"['Ġcomfy']\n",
"['Ġchair']\n",
"['Ġthat']\n",
"['Ġlooks']\n",
"['Ġlike']\n",
"['Ġan']\n",
"['Ġavocado']\n",
"text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
2022-06-29 02:12:36 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7F47EE696D50>"
2022-06-29 02:12:36 +00:00
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAEAAElEQVR4nFz9WbctSXIeiH1m5h4Rezj7nHPPnXOsrEKhMBEcxGYvtVrNltbSi/6H3vQHtPQkkdJLc0lsruYS2ZS6m6I4gCSaxECgCTZJAAQIkCDGQhWAysrMyunOZ9pjRLib6cHcY5/kBarq3jPE9nC34bPPBqff+Le/NiSSPnMTLGfADMrEChCTGQgGEGAEAmj6PzODwWBM5F8BAP8f899TYoL/kH+DyMwIsOmnCTCoKQAmMqtPIgLM1J9jxPWL5r8HA/zr/nAYiACQ4fh8AhHDzJ8DJioL9LfxxxgZlGhaNvzv5TlWvkNEZqZq5Q2pvuu0HiMzZS6PMdQdIYPVzfM3sPJivh6bdg4ww/TRdYN9PfWB/nHsX/I/RDAw6g9NL2LTeZjVZUw7VM7DCGTlWTCzaf0KMJgI6i/jm0kAoKaqRgZiQt0AIjJTNYMSyJjLEtTATAAZdNoHsvJ6qgYYE/tzCOYS4u8F03LugFr5kSJK5gdS9k3NYMbEVg5VCewfpuW9TA0EKAEJLJxVGUZdFwYlyonaOGYQiZkRBy17RagqQPBtrhLgSyQQkYF0EgT4j5gRgYTqcdZn3Pm5SYwBkDAol1/2HUI9JIDIyOqr0/E3CVTln7gqlUFdUagoSZEdIq3q8ZVFFHGHwYgMIKtqZmXZNgm1lfdF/di7z4GR6LQewIos1n+jfg5gJERcN820LMJwfHMoDAQiaN07/cpz/BeIuPyKlg+kaZOKEZqsDeO4Hv7KERT7QTDyg3XFICvKay46ZHVhDIBssmxU9AcAuEpHWT+h/J4dRQfm/zQGE7umAFDUAyWAyGCTBhIBxMVAGIgsux67QSnnW9fg4ueWl0Bk9eiJlI2QkY2ph861C0GDYdQEghb5LTtkRASjujBfTJUylJ91W0RVXK38xcWkiFDduiIS9Rjr27iB4fo51aT7htGdr9Bk3O0oCXdE4mjzXIlsEmwqkky+Vhx/0PfK97QqlN1Z5OTAzIWEy0l+RQsmo+YmrDojV8Lq1Kw+y4XNjYaZEbgIBwFGZR+tCLIZFfNfDqTujIvFpKyGO7bV10dfcbVf2aI7f76ioC6WdFTIO/tZ1qvFQ7OvHWXH6mdW+TGYEXF5OqHYS1hFDgYjdkxQ94ruWNHio9VQzmVSI2OQEYr9MEz6UtHKHc98x8MVt0gZRi6fkWjUgUlNja0aRzIm4uqraFJsYtQ3KK6sfBL5x6r/chE1MiKDKZFNrv8o7WT1Xa1KFowAniSqOHMqP2lU94bqIRdVKX4cVbVc+aaPsuPHuSWiuvCq3A5tiI56YUd592O0O/IIK08rJpPr4ZSf5SJNLrVHKFTPqpz8pH1FOaqew0BW7Fd5YhEWIxBDJ5Tk63GvhgKB6nOO2+jPqUbLrTIR3/1yAbYgM1PwHSd7dIy+y2ZQcAUlxbpPH+GC7jZOAaOCjF1otRxiPVCrkgP2r4OIwCD2h1T54fqidbllJW7TXfb8L1zXXTCrGZmRotoCg4FMoWA1M0CTWQwcgGwKYmImPyHSoxhx0V2URTIctx3NvRIR+K6u4/iHqllyEWCqto4mb1UFSB1TFsfKxFasERcRZzIzIjIYF3NQsE5RcTNmMiPi4pW/YgVdhIgd9ROZ2mQ2jvpjWtTJFERERKbFx/qDqEjvV963wrMqy8VsGJeYxzHEFEUBUCgREWNCssXm2R09NyoW19+QeXKHBvCEksyMid03WfEMVU2qXzUzUtcTc/V3lMLEqi5tADkyYnNQDWNmdXElM70DA1BCPwI7YGd27F7sEnO1fCjo3MXG95Pq+wKG7CrJVn1steNUXV7xG2rmUYfvjTtAGMqJqgmT5js76K7Wf5GIiUkhVXnNJBCYKAAgNpCpgV07pRhH96wVQPg2OjRiPyIq6mzVR2k1BYABzIZJ/InIT4wYxa7S0ZYU4SvggbhIdjWMLj
MTXJ2O0SqqcrddI9tid4nYjUFZUbGepFC2O/6WK2op9uMISO48h45Bgrmdru/r0mhA+bZ7wEnZiw06xg90R0mKVVSqqzQr3EM1IRWtF8nz3xcQAKEa2ipA6gpDk78huoNiCH74DgNtemh1JoBH1ex6Y6YgOkJehzNgEJgLDKxQkJgIRMx1Nb6fFdvWLSlw7Y4QF2g0xRK+nwAmo+c7rWDzkKRsQMWVRVOMwWZQNeIJ6GnBjVQ3BACzamaQFkHSNGogUzIyhhT7xkXGqKAixeT/4Y6Minph+oy6SUTVRk4OqxBBakfkctdZaX2fo+stYso44mdUQ+DxMIpvmZ5o0zsdP9pqvEt1mWVF1Yb7Mbhrq8JwR/L8GdN6wBMaRIUz5TlENQwiqppj5TTLR0wvXvBxsQj+FaJpF6mAxvoKRTgK2prWY/VJVaIqD2eTEDLbFH35KaI636NEVCtkACDMkwF2aa4xx+S6/R0m2SpSYHcfNfEv1Vz5y/NRKAiAqkcRqC8PqDmfWM2YY2ACgw2qk68umnMXakzsXYXU7lVQoQrA7tKNGKZGREZmMImONNwrm5qZWTYHccfDrx9Ix/D8LjQrsRBP8utBji+mBrNc3vNODOo+sdjYKoTVc04SO4UIRVCrCBWDbDVYLuukaRU17K0uhpy8cGRzZxHFqJtTAHxXp/GVGIbqW/jL8+SgDOqIClpjpoJv6nrLko+Wl8j3bQqjKyKoREchOLkomAui1fVMu0MAFE6A4E4oWgR+ik+L7BSak6Z3K0JW9oQc8FSvTXXFd14BjoAJRg6K7Ag6CEZQqAtXAXyODSbjgymeAUDM/m+r2mbVDvn762RnmKa4qB5u+TsRETPRxFfejQnrK3gEo0bEbmM80FClQExQJV+McxK+f1Y8LXH5NJrO3xyO+7OsYo477tbAQjCYFp9IVEB8gSdFyDwsAzOZGrmPrXpDhQR2gGeTHWeCqRFT4Zo9Jqi+tdjdsu8w82isIuqCs92NkBVg7c/Rsk7X9qIPBDN/3yIlDF8NldMsqQjm6jIcEVVz6KzJZOKPzym+1TmK6jp8tag64L9J1edgIiaKwWUig5JBtTJo1S+R5yXAVoSxuiMXG57+UVxvOUGtX+FKjenkoOoJlw0yaLUOk3xP8YoW2rg6iikPAhiQYS5wRFTpsMkRuRdzVdVsBlTDxYDdPV+UeA7sIV19Q1Tk4HJC1SdU/a0+BwRYcPOudbutajQxCOxoiohIaDJOWqCiTQRMhZsFhU5BrlXg577ckWJ18OoaOUHEaiNglbks2zKlxybsxAyPVyuAKQwnEclEA1YSZjqcikHqu5s7+rL9VF/fAwo+GjeFTgjDvb5Di8nMcDXlRnfDNvcDVF/ruJ5pGUXypoi9wuvyqkQecMFjR3cbLiVcd8QqYV0RC8rvgCqDaFZOsyiGFGvlhsMAglgJ0l3v6Six4vvqSSv/lEnNyGDEbKoGj3qohuH8FWflNs6qqWG3s+X8hEhLPq8IExmX3JfHi0UsDUQkRaWKYFR+sMo+TTDWKj47SuARhnl+xAAEcxdOoRxHCUwc1RhplUv9Cnp1+292zE/xEbxVeauG6sjOHv2Uv44fY7WdVbwrSKHJ+U5+6A4WJUKx07iLCKtZgsNjKi9UrR2xWZWX8hyYVSN85w08IieDOg6tzyl7BK4HZvXJ5gCovlFxnNX6+npcDAs6KR9/J1Pgr6ETrzA5K3cUiqN+oDzGjlHUBG6K8UNBE27cjxZcC4fklgYEkN0JbEsQAwWETBUGcFmAFaUkzUYeEFeT6ptRwEl25XGjC4C0UFweXYsvWaEEqLlHKQ/xRU1At/4pX1QtaHhyetUqECZhPErItG3FftpXdgQEhOLEquVzTqPAQtSEWP2ZEmUXA3pHjjGliY/kAGAen92B3HBjQp7Ory6v0iPTh1VlmxTirtW1KdlSwSyVV3MDVcmV8vwincUuuvreQbY2aRMRJnoTNe9IhZi/y9LUFyrHZJMdnZ
4DO1IJ5QcmtTQuiuz2nqo7qDETAe6aQGDP9xQkWdyLTfgXBBIms0IGTv7THVcBVDRtmKsWCcEYpB7wWP10E
2022-06-28 00:58:17 +00:00
},
2022-06-29 02:12:36 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 16.3 s, sys: 39.3 ms, total: 16.3 s\n",
"Wall time: 16.2 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-06-29 14:37:12 +00:00
"text = \"a comfy chair that looks like an avocado\" #@param {type:\"string\"}\n",
"seed = 10 #@param {type:\"integer\"}\n",
2022-06-28 15:05:59 +00:00
"\n",
"# Sample an image for the prompt using the chosen seed, then render it inline.\n",
"image = model.generate_image(\n",
"    text,\n",
"    seed,\n",
")\n",
2022-06-28 15:05:59 +00:00
"display(image)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-28 16:16:44 +00:00
"name": "min-dalle",
2022-06-29 02:12:36 +00:00
"provenance": [],
"authorship_tag": "ABX9TyMtWeoRGqaMmLjAXcNCf4AW",
2022-06-29 02:12:36 +00:00
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-06-29 02:12:36 +00:00
}