{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Install"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "code",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ix_xt4X1_6F4",
"outputId": "7b0178cf-9818-40f9-f7ab-7d06a5f8ae5b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon Jul 11 16:30:52 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
"| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
"|                               |                      |               MIG M. |\n",
"|===============================+======================+======================|\n",
"|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |\n",
"| N/A   36C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |\n",
"|                               |                      |                  N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes:                                                                  |\n",
"|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
"|        ID   ID                                                   Usage      |\n",
"|=============================================================================|\n",
"|  No running processes found                                                 |\n",
"+-----------------------------------------------------------------------------+\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting min-dalle==0.3.12\n",
" Downloading min-dalle-0.3.12.tar.gz (10 kB)\n",
"Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.12) (1.11.0+cu113)\n",
"Requirement already satisfied: typing_extensions>=4.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.12) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.12) (1.21.6)\n",
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.12) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle==0.3.12) (2.23.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.12) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.12) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.12) (2022.6.15)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle==0.3.12) (1.24.3)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.3.12-py3-none-any.whl size=10673 sha256=c16cf367ea607357b8db7735f8859001d41a7a0b013224ac38cc730d7edb8c63\n",
" Stored in directory: /root/.cache/pip/wheels/2c/15/26/963d2a412cc59dbf07a77ac814ab5c8eb4e3a892415967af0d\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.3.12\n"
]
}
],
"source": [
"! nvidia-smi\n",
"! pip install min-dalle"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model\n",
"`float32` is faster than `float16` but uses more GPU memory. If you use `float32`, set `grid_size` to 3 or less."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8W-L2ICFGFup",
"outputId": "e382228a-50f9-4a01-d561-c232428cf940"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"downloading tokenizer params\n",
"intializing TextTokenizer\n",
"downloading encoder params\n",
"initializing DalleBartEncoder\n",
"downloading decoder params\n",
"initializing DalleBartDecoder\n",
"downloading detokenizer params\n",
"initializing VQGanDetokenizer\n"
]
}
],
"source": [
"dtype = \"float16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n",
"from IPython.display import display, update_display\n",
"import torch\n",
"from min_dalle import MinDalle\n",
"\n",
"model = MinDalle(\n",
"    dtype=getattr(torch, dtype),\n",
"    is_mega=True,\n",
"    is_reusable=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate Images\n",
"\n",
"- `grid_size` Size of the image grid. Reduce this if you run out of GPU memory.\n",
"\n",
"- `progressive_outputs` Whether to show intermediate outputs while the image is being generated. Adds a small delay and increases memory usage.\n",
"\n",
"- `supercondition_factor` Higher values result in better agreement with the text but a narrower variety of generated images.\n",
"\n",
"- `top_k` Each image token is sampled from the top $k$ most probable tokens."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 995
},
"id": "nQ0UG05dA4p2",
"outputId": "4af26983-0b57-4437-faf7-7bd97fa579d7"
},
"outputs": [
{
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1280x1280 at 0x7F6C72852E50>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 22s, sys: 2.91 s, total: 1min 25s\n",
"Wall time: 1min 28s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"progressive_outputs = True #@param {type:\"boolean\"}\n",
"grid_size = 5 #@param {type:\"integer\"}\n",
"temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n",
"supercondition_factor = 16 #@param {type:\"number\"}\n",
"top_k = 256 #@param {type:\"integer\"}\n",
"log2_mid_count = 3 if progressive_outputs else 0\n",
"\n",
"image_stream = model.generate_image_stream(\n",
"    text=text,\n",
"    seed=-1,\n",
"    grid_size=grid_size,\n",
"    log2_mid_count=log2_mid_count,\n",
"    temperature=temperature,\n",
"    top_k=int(top_k),\n",
"    supercondition_factor=float(supercondition_factor)\n",
")\n",
"\n",
"is_first = True\n",
"for image in image_stream:\n",
"    display_image = display if is_first else update_display\n",
"    display_image(image, display_id=1)\n",
"    is_first = False"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"Zl_ZFisFApeh"
],
"include_colab_link": true,
"name": "min-dalle",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}