min-dalle-test/min_dalle.ipynb

174 lines
1.3 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-04 14:21:12 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-04 20:06:14 +00:00
"execution_count": null,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:06:14 +00:00
"id": "ix_xt4X1_6F4"
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:06:14 +00:00
"outputs": [],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"%pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-04 20:06:14 +00:00
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-04 20:06:14 +00:00
"outputId": "73aa19d1-f2c6-4054-8d88-e882d21ce390"
},
"outputs": [
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-04 20:06:14 +00:00
"from PIL import Image\n",
"from IPython.display import update_display\n",
"import numpy\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-04 20:06:14 +00:00
"Note: reduce `grid_size` if you run out of GPU memory. A 4x4 grid has been tested to work on T4 and P100 GPUs (with `intermediate_image_count = None`)."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-04 20:06:14 +00:00
"execution_count": 2,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-04 20:06:14 +00:00
"height": 801
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-04 20:06:14 +00:00
"outputId": "3266cc97-4c97-49c0-c15e-f89fa7373c46"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-04 14:21:12 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
2022-07-04 20:06:14 +00:00
"<PIL.Image.Image image mode=RGB size=750x750 at 0x7F982862A510>"
2022-07-04 14:21:12 +00:00
],
2022-07-04 20:06:14 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAu4AAALuCAIAAAB+fwSdAAEAAElEQVR4nGT9abBs2XUeBn7fWmufczLzTm+sV1WvRqCqUIXCUACIkSRAiuAoUqJIarIkh9oh2bItWwq1FIqQqO6W2yG7ZUe41e122+1WU7Jsk5IoihQHkQRJjARAojAWUPP86r1687tTZp6z91qrf+SD6CF/3RuRkTci79nnfOubFv+rT/6r5Vrb9YO0FEqLMc1UkF59amkdBJJtriGp6xAAHdgkWrg4LFM4ejphIgYHIVCM49GiSIG71HGMTgefatGoGekqqNvz2TQtV22MxCDSwxtYkV2omjicifVqFNahS5GSqtkETYLRYl0Y84KI/domYlfCgJUwiR5og1nGtPZbmTTMeza19dQoHATI8ASRbH6wtdDwcWoB7BQZ2I4bly2ykINlZ3Y8LdcNEtsmjXIodIFkltooql0vrY3iqRYmOfRW0bxaG2M2LyYMoIVGuErp+5KDNSeVpl03zI+PphMndtcHaxSZ6pos4ZlRpczaOmHZqTDDWcm1tJptjJzc1wagkTAxWjEr4nX0OmpBrROEVCQ8XFqroipWITaup2yiamrd5NlalCJIF6F7ptNMka2ltzGYFFXPzOgAi+YxrYQNKtXTq4oU0UyiThMDfUHGavIWbRDOqEfK8KoUt6BQJhxlQ2tzBVhuSEqquTuqZGTLKWM978XjyN01e3iXDFHP2FxTkyUbc6wuqUPJRHggCIUgM5He0BxWQKAGZnPkCnSMBTGiOzH7vbf99TV1PFpObIiok/dDD/XqnqLpYqlKqeER06zTZLaERjEo6ckgzJBmQeS6VQ/2poNkJFIFDqFQkInWMnwym4paQwo6zS7JlhFtnA+KWE+oWbsiPSJomSwCqrgpxjataxS1QaEIBzIoVIDITEerY9c3lfQk2Ut0GWz0jGqqhS3lePKMaa6gWSPJ7IOTpYjJsq5aiKHvtDmOPDrVLtMyDEB4E52KRfUpQhWDICFTCDOLSRNpAGuLyak2GASJSE6tFVVf1715X6JmPazLo3np1uumWoZh645zZ7oZyK5Wy/SuD0HNnMgerXTcFkcbD624qXkTK4uuLMZ1vXLp1oXXDrZPnTv38APzbjtjqu04jg7cxv6OHcfR8fVpeXix5bTofVZQuMVIm5EDs6CrcnzYYW4HSzagcWvk7qQlFosaq+V0c310y+txMaevlqvmLYCokHHyaDBqA+rUMl2KFxWhL8cpwor2RAORCVW0Gg5ztMwYV61TEV1msIkCXTQNSvWptXVfImLdgFrn8D7SQzPQIVNYI2NZmwd77QZ197FBAxbBDCpZI6c69ZaCXLZVsZmEikhTqS57/YnPXHLcGNEErOgVyyWQKEA4So8AQiAEEwCQQIMArgiBNDAgHUCAQCAqPEEDGxAQRQqoCEcSVIxHmBmaA4T1CAMJBBBIQVRIAIADJCRBQg0AlEAiHSBEEA0IGMGEAxQIIYkYkSNQMTqQEB1QFswFcwvaWVe7Xudb0nf73m6slicX82uHR6sm2fzEqTtWY91a7JgEoJC5uiaGMiy67fm0Wu1fevPg1pXqtNOnt87dN96w1e89g+kY3QLPfAO3Pgsc4tHHH3zPww+dnZ8o3enTfna7O3GqzrYKj4diu4szZ3To16uxnHvLU1+7+Yl/8Vuf+Z/+B8ABACCJP3jl/+JnmGKhmCpWiUzg1OLBH/3oT/yZ73zHY/cvV91rF2+8eW154+jCtf2jG/tHy2nV1q2tJ3MdgdXRilZcmQFOVco8iKmi007AMatTAI0pKJKNUEJAKOmBRARGV0pnDMLBJAKSAWQiBXWUPqkZ3jI7lR6koyGBSACgg4loCNUUhbdMG+aiNk1OkBCQ3lZ90dpWgcymaEpJlM0dD4IWiEiHO2Jz+4y0AitAwhMgIGgNBciGOiELMmAFQ49WwNN2Y9mjHc5V1lZq9YxZz6LjVNuq74Yjpf
bDXpWdGm/mchLMgn3SIytz17ou5ahVOMUKmEYl5biuRboSlOAao0A62lBmh21ZuhLZRFiOfCWyUlno7G5fXKsHY5nEpYO6p6M6UvvOQpQZwQovqR36lY/JrghzmhxJoEiZcdvDqjS4OjVbjUgXhJWd3CpTXfFAMWRaojFEOLTW+rKb01HERO0stwcfAARypjONipjaukVSy2zWTkhdVTumrpEF0M62jH1MS03ppOukKZpU8amZdbNS1BXugVZQ0FaLxbY7MUIyh7JozSyGeSb36zy7ycdOYmyHUZfaocRRZ7laL72O4a7ZWnhAJEFUyyqqJlpExAch67QSLodOpmks4UW0enp4NA5ahAFvbfTi3pd5Jr2l1FxYrw6im6YmVCsFbWru5mYUkNFcI2l9S83MyCSWbawlbV56QKLF1LJnbwyvNbIajqXbtpilRyIApIRE9WRLUsauuyNzUkLCplYke4Yhw1vRbhFtGe3YBBJqup3RkBGU5jUhgRrpov1gW+Zj9WMRRSogDAAZHEVSiGyYKWQkrDSZDGDBuPbr6Hl4WDRgpTUftrZNmG0SjQZYEYY1T8nsu95UWroSPWaz7BtqExeQrMzIjFZUbbbVZhnVdZ1QJVUyEBVNuk4aulIyCVBplmVsnpDOaI41W2Z2NtuSbTiWOmaqhCOztToRmA9DLIbAyONgCgqRCY9EmqpAjEAmjNkrLCmejVo60rGMYEPth60tzMZ6IGrNndIzvTmaGVVmdY5ceyfaCthLFKG1aFqoJLNRomjf53YBRzmsSGVP1BTJFlSoqElvWYCMEFEPyNbQbSU6ranod/a2F9Jp5zHtbp3zpqWnA+i78DpfcJquiFW0hiPrMtXIspzqsjoNC5uKhu5x0YRv3tq3rZ1zd5/qdk9ffO7l8eb1u06zO3X6leX4wtdevvjNr/d43YaxVUsNZjeYzWbbd5zdO3XCZt18iLa+1U6e3Ds0Pax73e49W4vTx21Vj0dpafOB0nJ5VGtKURF4NFQ3LVY0HT76IB0VLmPSxxY0zmxR2GebWmZmODwZimwhIyZdzCwS2VIUFEItbIr0yK50ytqImlTpTsSJlu2mrgGFT0DLbG5BWyx8x2o7lkZRrSqAkq4+5cQuTSwcs9mWsR98GDEh13VY7Nc14hTkCGWCzrFeod+FKXJEVkRCDapAwCtCUDrIEt6gM+ocucrSEIKoiA3sEPQztAFtDVkiBaFgIhOdYazoO1AgBZLQglQQCIEntKAQvkIAZQtdjxjhAQ8UIjeP+Q5dARyYEA4oRNABAiDQGqQHCrxhV1AUoxvZk0pqyT0OYxmOFlvLoV8OllH3h1kHu3vWbYVs9wvEtD2UKTagScK2VbdFh8Y8Xh6NOl1fHuLVy+3Kwa1Xr6OeghIZuL6PCcA5bO3sPXHv7rnFLLk9+CA2H/qTs3N3zvvjsNnu9j2P7fny1NPPrf+z/+inf//XvrKBKSIUUtXe9uBb77hnV7u2PFweHF5bj+uaTSSDqOM0G7i3JePh+NoV3Lx5/NLP/Orf/5nP3PvD7/vhP/2+xx65d3Z6hSwHbiacL4uZyszGsdph293amSKXtUain+2Gcarc2i5sOtVJWZwCWjeUhgxFdY9gURU0ae2AWXb2ZiJttU40JeES6QEPMuHoy4zM8JVyZvM+uhaxJlvQuiKsM28HOWWxYThxsupRW7ut4TJOkyGRRnBsrZN+NsK1eK5N5tvDHN6OsN4cdxKZdaRwGBY2z9ZqjkIiFZkuGYB7TTVNdRhKgc3gBBNjYnsGP7ScmkzOvpgCyIA7EwYpPaEdmHVqkFq6vpmIM3JEJLNQQ6OFmwCpCGRkywZi3gnSN7hN2yBKQCaZuiKKbCoJLDV94ha7TqbKuaGbM5JonqmRoEK74qxgCIKMcOSEqpaFKVlBV2xBNudpTfOeDCaDLSIbVIdO2oLSKIIhYcyWOiVTQ2eLbO1YhMlC6QrX8HSZOoG0rEiwikCyCEXLOqMGBkZmIC
mdkEzSRMwzR0+VyOZEUylKj808ElBlN8xTmBARK5qgFqXXdZmvQazWq+W4ntarw/VqfXh8c3/aX0lkKCrq8
2022-06-28 00:58:17 +00:00
},
2022-07-04 14:21:12 +00:00
"metadata": {}
2022-06-29 17:55:23 +00:00
},
{
2022-07-03 22:40:27 +00:00
"output_type": "stream",
2022-07-04 14:21:12 +00:00
"name": "stdout",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-04 20:06:14 +00:00
"CPU times: user 34.4 s, sys: 897 ms, total: 35.3 s\n",
"Wall time: 36.5 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-04 12:05:55 +00:00
"seed = -1 #@param {type:\"integer\"}\n",
2022-07-04 20:06:14 +00:00
"intermediate_image_count = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"None\"] {type:\"raw\"}\n",
"\n",
"display_size = 750\n",
"image_shape = (display_size, display_size, 3)\n",
"\n",
"# Reserve the output slot (display_id=1) with an all-zero RGB placeholder;\n",
"# update_display calls with the same id replace it in place.\n",
"zero_image = Image.fromarray(numpy.zeros(shape=image_shape, dtype=numpy.uint8))\n",
"display(zero_image, display_id=1)\n",
"\n",
"def handle_intermediate_image(row_index: int, image: Image.Image):\n",
"    \"\"\"Resize `image` to the display size and swap it into the output slot.\n",
"\n",
"    `row_index` is accepted to match the callback signature but is not used.\n",
"    \"\"\"\n",
"    target_size = (display_size, display_size)\n",
"    update_display(image.resize(target_size), display_id=1)\n",
2022-06-28 15:05:59 +00:00
"\n",
2022-07-04 20:06:14 +00:00
"image = model.generate_image(\n",
"    text,\n",
"    seed,\n",
"    grid_size,\n",
"    intermediate_image_count,\n",
"    handle_intermediate_image\n",
")\n",
"\n",
"# Always show the final result in the output slot. Presumably the callback is\n",
"# not invoked when intermediate_image_count is None, which would otherwise\n",
"# leave the black placeholder on screen -- TODO confirm against min_dalle.\n",
"# Assumes generate_image returns a PIL image -- TODO confirm.\n",
"update_display(image.resize((display_size, display_size)), display_id=1)"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-04 14:21:12 +00:00
"provenance": [],
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-04 14:21:12 +00:00
}