min-dalle-test/min_dalle.ipynb

227 lines
1.4 MiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-05 15:53:27 +00:00
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
2022-07-01 22:16:55 +00:00
"### Install"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-05 13:37:52 +00:00
"execution_count": null,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-04 20:30:39 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-05 03:17:31 +00:00
},
"id": "ix_xt4X1_6F4",
2022-07-05 09:52:30 +00:00
"outputId": "20acdc7e-e2e5-4d27-95cc-8b34f069fdcf"
2022-06-28 00:58:17 +00:00
},
2022-07-04 20:30:39 +00:00
"outputs": [
{
2022-07-05 09:52:30 +00:00
"name": "stdout",
"output_type": "stream",
2022-07-04 20:30:39 +00:00
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
2022-07-05 09:52:30 +00:00
"Collecting min-dalle\n",
" Downloading min-dalle-0.2.27.tar.gz (11 kB)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.11.0+cu113)\n",
2022-07-05 09:52:30 +00:00
"Requirement already satisfied: typing_extensions>=4.1.0 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (4.1.1)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (1.21.6)\n",
2022-07-05 09:52:30 +00:00
"Requirement already satisfied: pillow>=7.1 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (7.1.2)\n",
"Requirement already satisfied: requests>=2.23 in /usr/local/lib/python3.7/dist-packages (from min-dalle) (2.23.0)\n",
2022-07-05 02:52:00 +00:00
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2022.6.15)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (2.10)\n",
2022-07-05 01:42:27 +00:00
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (3.0.4)\n",
2022-07-05 09:52:30 +00:00
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.23->min-dalle) (1.24.3)\n",
"Building wheels for collected packages: min-dalle\n",
" Building wheel for min-dalle (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for min-dalle: filename=min_dalle-0.2.27-py3-none-any.whl size=11247 sha256=788e2010694dd0085a6352e07de88d108d171e4f04457575a45e5b02c7914822\n",
" Stored in directory: /root/.cache/pip/wheels/da/00/12/9761a7506eb7e3e1f4ca383cd37e426dd146eea5bc5cb016f3\n",
"Successfully built min-dalle\n",
"Installing collected packages: min-dalle\n",
"Successfully installed min-dalle-0.2.27\n",
"Tue Jul 5 09:43:32 2022 \n",
2022-07-04 20:30:39 +00:00
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
2022-07-05 09:52:30 +00:00
"| N/A 34C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n",
2022-07-04 20:30:39 +00:00
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
2022-06-28 00:58:17 +00:00
"source": [
2022-07-03 15:17:37 +00:00
"! pip install min-dalle\n",
"! nvidia-smi"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-07-05 13:42:27 +00:00
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-05 13:42:27 +00:00
"outputId": "6c67dde2-3f66-4461-95ec-cb6b75f36b44"
},
"outputs": [
{
2022-07-05 13:42:27 +00:00
"name": "stdout",
2022-07-05 15:53:27 +00:00
"output_type": "stream",
"text": [
"initializing MinDalle\n",
2022-07-01 21:34:23 +00:00
"intializing TextTokenizer\n",
"initializing DalleBartEncoder\n",
"initializing DalleBartDecoder\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
2022-07-05 15:53:27 +00:00
"from IPython.display import display, update_display\n",
2022-07-04 21:27:02 +00:00
"from math import log2\n",
"from min_dalle import MinDalle\n",
2022-06-30 15:25:24 +00:00
"\n",
"model = MinDalle(is_mega=True, is_reusable=True)"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
2022-07-02 18:54:18 +00:00
"### Generate Images\n",
2022-07-05 01:42:27 +00:00
"\n",
2022-07-05 13:37:52 +00:00
"- `grid_size` Size of the image grid. 3x3 works best when displaying intermediate outputs. 4x4 has been tested to work on T4 and P100 GPUs with intermediate outputs off.\n",
2022-07-05 03:01:13 +00:00
"\n",
2022-07-05 13:37:52 +00:00
"- `intermediate_outputs` Whether to show intermediate outputs. Adds a small delay and increases memory usage.\n",
2022-07-05 03:01:13 +00:00
"\n",
"- `supercondition_factor` Higher values result in better agreement with the text but a narrower variety of generated images."
2022-06-28 15:05:59 +00:00
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
2022-07-05 13:42:27 +00:00
"execution_count": 2,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-05 00:56:45 +00:00
"height": 819
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-05 13:42:27 +00:00
"outputId": "7ebc8342-20c0-4b04-9f5a-45dcce913db2"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
"data": {
2022-07-05 15:53:27 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwAAAAMACAIAAAAc45fZAAEAAElEQVR4nEz9Wax0W7seBj1vN8acVWt93977b/yf/5zfTew4sZFNAiGdkmOimMQ+OISEGwhWuEjCBQKkXBAkMBAhJHKNkBAiUgQKookCBKE0hIg+EQiiJHbsxL1Pf/5m769ZVXOOMd6Gi6rtUFpaWiXVqlU1V805nvE+Hf21f+RvPs+jpoJWEY2cMceltwQPkki1gkepUDC7F1Vq5ojMJkYWno3K2aO0iDPhiUigCHHAFLGQDn0FbwigCaQjACTWgm3YLpCO3pELbLi8AgxqQMAa0IGJL644EwLYBSEAQ4EiCMEdamDCGbAGA2IiFJuABUGIBIC50BgsCMAESSBDnqDCCmwNQlgJJ9QAKUwhBTggEIIw1LB1sOGcsB0vGz6/oTOko05AMAoquE+wgQ0ovH3C5RXjRGsgwwpMhgGhOD7DDPeJdw3jBAs8UAADx4FLQyWaYSywIhnpsEJjjInjRL9g33Ab2AwzsW2gxMc3fPUOzrgd6AIfcMJLQwag4EQASZhv+LLh8xuYoTsWEAQlTEczjM9gIAmlqAIm9o5QUCEmVDEnALDiOMCBpiAggCgASML4jN1QDgDBWEAwNAECMWIiE1vHfaIZYoIBKISwJhIwhS+QYZ6owNbBhBkoQhbEMN9wNdzvyAIrPFEEDUDhBE5UYi0kIQsViEQWSOEBBCIhCl9ohuOd/ef/zvWXP6EEKhDguEMAFOAQRgIFRMH4+SsVQCAFTmAGL7CAG3yqsvtEBCKZwbES3DaLDJBkhhCls8/jcpUxZt9aJBP1qvTKjGiNYh5FGSSYgIiIs0qgKQVLTo9EwJMSWAcYbTMQCoRKAhHpOd62qzGVZ5QzJxdpb0HClWiCrDx9rUh4vRpReVK3zlm1oqiyiopszLd3rz3GDAZcKBgwa55MzMLlLHTO5ZU110Uo1wzqbTcACVR6OZjafX6+7gbQjJToSnK69lZFi4vgU7Wdcy0vn/na2X1NbPsuXjWdCevtNsS2iEAjRMAUM7EIXGBCHhCCD4zEdkGcIAE5JiAK63CBACwghTRUQQWaiEQySCCALzCBBOVgAgpQzIUCiCBAAnMhAGU4AQEGZkAEESBGJmaAFZmowmZYDi+gcAaEkA4WpAMFFyyCKCwAQAzzQAAMsGACcOQADJ3gjiBoYRaqIT7ACgGI4ARCoAJORCAJ6wQ5vCE20AItLMUgVKEcysiFEswDl1d8fP1v/NP/yfvd/fOIbkI417FWvr6Yr0WtnwFyUiISjDE3NaGYhBWkZOlLxZNYuFGGMAXo8AVwrvVqFBXUemWBJBNM5I5jHO9f25yLTMOFgkVlUZSXCVGt41xpIrSLEq+bF+YyURZUFgVwP2YVh6+t11xIpiqBc1WV0Frruhl8nDPkevn02bvaLrMqpjfRAgqQiQwvLU2/v7zamgniIktHVGljd28q7KeHL9bbrbq0TUdxLW9bh6dHaS4XJlQlEVGiKKhAwmCqrMelNJxQxADAJAAVY8WSImGwUGRWEbNkJSO23nytFUWia+VuxhkQnl6qZCqPC60ndTGf95fX7itKuJwyWYvaxsmg5E0TlWN5Eil4E01f3CyRRAAg4AICeoz7l69bLg9FBSu0vLRFiAorfPam7nN4EepCjJjcenCWB7NEFgoRMub5eiX3ILNw4tQohJZPbyKUa9Q6V3XbOJbQcCjRxhSEzMJtnVms1TQmZDoxgZBEBQ8i0ft5++La4zyW0BilsFxTzZ2bL+4cVHmcZ7DWKl5LhIN1rnJfrIgxWLWImKu3Rr/wn/hbZBQQ7aBbxWfNjnhPtMD33Xh1ubsSDyzmds+cON8FRvnHLZkvP1rv17h9fX2DbxxcmOmMagMTRJrlMXAJ5IvQ+3LOnoCgDHQDG8TQN/QvcCqscFmgF8BAABPUYIa+oQlSwYBtEEMkiMECFHhCCyfQX3B5xdcfwA
VJECEEBASjdUSCC/XAEI7lsA4TlKAAFqijgGigwPa4oDC6oivmRDJeOyAoQSa+8wXOhcZYC8XAAAJsGIV7oTcQ4zgxA73B/YnVFkCAbViBCGSABD4RDgZKsO2ogXVgOcKfAA4CETAjFrCAgAfaBawQxlxQxQRoQApg0I4v3+HjzzADmgDgBSV4wBp8QgBOeEEIrCBgJlZADQHIG2IidrxcMQrzjkoQIIUoMKEIewcRjhtqIQgFVCACYlgTQmCAgXAUQxtWoRJFSIBPSOC20F4gDfNAOrJAiQKawhN9hzDOA1gIAEAVkEjBmJACF4QQDhJAMRwIEMMBDUwHFoigHR8+ouJ5eIPhE2oAQApK0MAf+Yfws5+ppqvhbeDSWKrG0IqoUiJkEfGKAPBiduZy9sr2Zd9prs/lxSVeTOsaGFqfCle7bu5HDc0FJEGY6AwnFffsyhet8AoklXS1T2N64qLymvyZ74vmbdm713e70y2PqNrA4OBEaH7IeO0v18x73rUmRYGIkLNKtC8farIJEdVR3mRvajHHiBDIxqzqlusDr6/Rvru9xtefTFLIjXgmMdcZabqd8zSRXThqnfCGy0u/rHnefYnoRgJeLxXfsP8M7TuXl/jmo2lULU2tYpI6fZnua62g2IQqaSAufJkRKForOlsnJh17rZ8gP8T+3fY6PnwQISNn1AATtfuolKLwRRACigaRgFX79DjGoSzEmHkCiloAgztoIQnNMJz2qyatZFw6IiGPaz4QDgGa4TagDbvgPiGKLAihAscD0AR8wAtEkAZ3JIEJ50KTJ8KmwihsGyIxBwAwwxPuCAAMHyACMfwEEqWQC9ZABYhBglpGvOAgA7/g0wdIoE6QoAjC8AP2gjMgEzgRhHTIBjTcj+emJXfwDfmGeoW+IhPHT5EC7kAhgTEwJ6yBFJww/0f/yX/M375+pzxAR8zJeKf7S1aR3xDKurlN4G3cTeSr3L+hz5/p1uvyPfmyFf2MbguwJGC1wom8C9T6z+vrj8+PJIvcBASuVVlFK6CMnXOBlxRcvpT9wxx3rl58Jfk4PyvXNOn28j1cPnz6ycEUScZc7iW4zfCqpt3Pszc6Z6YgPC98Scendd/7tiXe4nNr+nVUa1/89vb+x59/ayEzam9SjpJ6W7PbLoUsN/YMoMt590u7GLVvztultZeUz/i0fNxJ9/7uF+jdj99+uigasQjuFVyNId3oNt4KVFFdbGqV1wtfjjwLy4uQ9c72SWsie6mpfJq3AHfhF7HhI7giwMRE1JU4MTA9QaIk/XWXeaxjnE2bCQuoFPcZTZoQq5QyHDXKG/q7bXfPWwxl3VlEi+e6S4zkfd/pXEwx4+zUwBCldJDInL517YJg3PLW6vqqOyLfcDK1TgzJd0GfdH0GXvpVPh8kNeJUEiqw8PIgpgQjomvOVSHFyRfaj8hbjEZ2Sf7Mn0feury/2Lu24hv/CG20iqXEKyhutILsC+wjhscdDm5G5REAdIylqp1yYBy4tXr3BV5qzQ/0FmgySTT68G94fXJ/sRf69Bbkt3Qt1dLlQ7laa3dfuW84F/3oj/7tcTvXWhcVpHxDo6F6yixkF1mylx6YAWjRRJt51xUrcF5K2X60vpMeP758mGGyavLMUC5NLSCaq2dkK6Qgr2BGK7DBEwkkYDv2jn0T/U6cA1eC2vNK0Tqo48KYgCqkQwhQaAMCmWBFPtbFgjaoYdsghTkRjkwUnlv/vsEnPCENm4AZKig85woEeCIdzNh29A11IgZgYIAJpqjHVAC47JAGMVA+d3WZOG8Ao+yJcvIx7Uhogz8WZoc/Bg8AKQoQYCYogMCcKEAatv15NyeiEAABQSACCOkgQhW6IQuyoRwoVIEZMRABMPorGlDAOUHxfMBjyVeBLtwHKlECIaRAFQSg4I4ZqBPLQYb9BUiwoxix4IFImIAUtWCGdMw7YMh4bpRJwI7zBAkiAIAFrKhvH+CEnFgDAKSBCwSAUI5irAkIVED0RE
J+AIzI57MV0IA1kYzHLp0AMMpRBAAO5MKa8IAArABhOrgAQhICkEICc6AJAvi7/gF8/VtiFh1I18cEas613
2022-06-28 00:58:17 +00:00
"text/plain": [
2022-07-05 13:42:27 +00:00
"<PIL.Image.Image image mode=RGB size=768x768 at 0x7F8C7B088FD0>"
2022-07-05 15:53:27 +00:00
]
2022-06-28 00:58:17 +00:00
},
2022-07-05 15:53:27 +00:00
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
2022-07-05 13:42:27 +00:00
"name": "stdout",
2022-07-05 15:53:27 +00:00
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-05 13:42:27 +00:00
"CPU times: user 34.2 s, sys: 503 ms, total: 34.7 s\n",
"Wall time: 35.1 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
2022-07-03 22:40:27 +00:00
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
2022-07-05 13:37:52 +00:00
"intermediate_outputs = True #@param {type:\"boolean\"}\n",
2022-07-04 20:06:14 +00:00
"grid_size = 3 #@param {type:\"integer\"}\n",
2022-07-05 03:29:48 +00:00
"supercondition_factor = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
2022-07-05 13:37:52 +00:00
"log2_mid_count = 3 if intermediate_outputs else 0\n",
2022-07-04 20:06:14 +00:00
"\n",
"image_stream = model.generate_image_stream(\n",
2022-07-05 01:25:38 +00:00
" text=text,\n",
2022-07-05 13:37:52 +00:00
" seed=-1,\n",
2022-07-05 01:25:38 +00:00
" grid_size=grid_size,\n",
2022-07-05 13:37:52 +00:00
" log2_mid_count=log2_mid_count,\n",
2022-07-05 01:25:38 +00:00
" log2_supercondition_factor=log2(supercondition_factor)\n",
")\n",
"\n",
2022-07-05 02:52:00 +00:00
"is_first = True\n",
"for image in image_stream:\n",
2022-07-05 02:52:00 +00:00
" display_image = display if is_first else update_display\n",
" display_image(image, display_id=1)\n",
" is_first = False"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-07-05 15:53:27 +00:00
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-05 15:53:27 +00:00
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-05 15:53:27 +00:00
}