{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Install"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "code",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ix_xt4X1_6F4",
"outputId": "c72b1def-798b-401a-fad5-25e715f60dfe"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sun Jul 17 11:27:12 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 36C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting min-dalle==0.3.15\n",
" Downloading min-dalle-0.3.15.tar.gz (10 kB)\n",
"Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.15) (1.12.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.15) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.15) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.15) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.15) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.15) (2022.6.15)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.15) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.15) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.15) (1.24.3)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... done\n",
" Created wheel for min-dalle: filename=min_dalle-0.3.15-py3-none-any.whl size=10562 sha256=3fa7138ec8078d7442a6a0a4c02653759b255d5853bfaa1710c5746399111aea\n",
" Stored in directory: /root/.cache/pip/wheels/92/58/05/36292d8d8de42c1d8afa532a27cbd026daaeab31819ac2b3c8\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.3.15\n"
]
}
],
"source": [
"! nvidia-smi\n",
"! pip install min-dalle"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model\n",
"`float32` is faster than `float16` but uses more GPU memory. Change the `grid_size` to 3 or less if using `float32`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8W-L2ICFGFup",
"outputId": "9ae4d35c-52f8-4733-9972-e46755cba899"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"using device cuda\n",
"downloading tokenizer params\n",
"intializing TextTokenizer\n",
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
],
"source": [
"dtype = \"float16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n",
"from IPython.display import display, update_display\n",
"import torch\n",
"from min_dalle import MinDalle\n",
"\n",
"model = MinDalle(\n",
" dtype=getattr(torch, dtype),\n",
" device='cuda',\n",
" is_mega=True, \n",
" is_reusable=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate Images\n",
"\n",
"- `grid_size` Number of images along each side of the grid, so `grid_size = 4` produces a 4×4 grid of 16 images. Reduce this if you run out of GPU memory.\n",
"\n",
"- `progressive_outputs` Whether to show intermediate images during generation. Adds a small delay and increases memory usage.\n",
"\n",
"- `seamless` Tiles the images in token space instead of pixel space.\n",
"\n",
"- `temperature` Higher values make it more likely that low-probability image tokens are sampled, which increases variety.\n",
"\n",
"- `supercondition_factor` Higher values result in better agreement with the text but a narrower variety of generated images.\n",
"\n",
"- `top_k` Each image token is sampled from the top $k$ most probable tokens."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "nQ0UG05dA4p2",
"outputId": "7faf570a-3444-4e6a-d117-a4e1d0b09ce5"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7FA55DFEEE50>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 55 s, sys: 771 ms, total: 55.8 s\n",
"Wall time: 1min 3s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"progressive_outputs = True #@param {type:\"boolean\"}\n",
"seamless = True #@param {type:\"boolean\"}\n",
"grid_size = 4 #@param {type:\"integer\"}\n",
"temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n",
"supercondition_factor = 16 #@param {type:\"number\"}\n",
"top_k = 128 #@param {type:\"integer\"}\n",
"\n",
"image_stream = model.generate_image_stream(\n",
" text=text,\n",
" seed=-1,\n",
" grid_size=grid_size,\n",
" progressive_outputs=progressive_outputs,\n",
" is_seamless=seamless,\n",
" temperature=temperature,\n",
" top_k=int(top_k),\n",
" supercondition_factor=float(supercondition_factor)\n",
")\n",
"\n",
"is_first = True\n",
"for image in image_stream:\n",
" display_image = display if is_first else update_display\n",
" display_image(image, display_id=1)\n",
" is_first = False"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"Zl_ZFisFApeh"
],
"name": "min-dalle",
"provenance": [],
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.13 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}