min-dalle-test/min_dalle.ipynb

231 lines
170 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-07-01 18:19:35 +00:00
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Download models"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-01 16:14:31 +00:00
"execution_count": 1,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-01 16:14:31 +00:00
"colab": {
"base_uri": "https://localhost:8080/"
2022-07-01 18:19:35 +00:00
},
"id": "ix_xt4X1_6F4",
"outputId": "8ff09575-fb6d-4c74-8ce5-e5d2aacf4a23"
2022-06-28 00:58:17 +00:00
},
2022-07-01 16:14:31 +00:00
"outputs": [
{
"name": "stdout",
2022-07-01 18:19:35 +00:00
"output_type": "stream",
2022-07-01 16:14:31 +00:00
"text": [
"Cloning into 'min-dalle'...\n",
"remote: Enumerating objects: 511, done.\u001b[K\n",
"remote: Counting objects: 100% (314/314), done.\u001b[K\n",
"remote: Compressing objects: 100% (161/161), done.\u001b[K\n",
"remote: Total 511 (delta 202), reused 239 (delta 149), pack-reused 197\u001b[K\n",
"Receiving objects: 100% (511/511), 1.72 MiB | 8.46 MiB/s, done.\n",
"Resolving deltas: 100% (298/298), done.\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 782k 100 782k 0 0 950k 0 --:--:-- --:--:-- --:--:-- 948k\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 449k 100 449k 0 0 565k 0 --:--:-- --:--:-- --:--:-- 565k\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 267 100 267 0 0 731 0 --:--:-- --:--:-- --:--:-- 729\n",
"100 2117M 100 2117M 0 0 89.0M 0 0:00:23 0:00:23 --:--:-- 80.9M\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 267 100 267 0 0 737 0 --:--:-- --:--:-- --:--:-- 735\n",
"100 2818M 100 2818M 0 0 96.0M 0 0:00:29 0:00:29 --:--:-- 99.8M\n",
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 267 100 267 0 0 719 0 --:--:-- --:--:-- --:--:-- 719\n",
"100 178M 100 178M 0 0 117M 0 0:00:01 0:00:01 --:--:-- 133M\n"
]
},
{
"data": {
2022-07-01 18:19:35 +00:00
"text/plain": []
2022-07-01 16:14:31 +00:00
},
2022-07-01 18:19:35 +00:00
"execution_count": 1,
2022-07-01 16:14:31 +00:00
"metadata": {},
2022-07-01 18:19:35 +00:00
"output_type": "execute_result"
2022-07-01 16:14:31 +00:00
}
],
2022-06-28 00:58:17 +00:00
"source": [
"%%shell\n",
"\n",
"# Abort at the first failing command instead of continuing with missing files.\n",
"set -e\n",
"\n",
"# The next cell chdirs into /content/min-dalle and imports min_dalle from it,\n",
"# so the code must be cloned there (skipped if already present, so this cell\n",
"# can be re-run). The weights are then downloaded inside that checkout, because\n",
"# the model reads them from the relative path pretrained/...\n",
"cd /content\n",
"[ -d min-dalle ] || git clone https://github.com/kuprel/min-dalle\n",
"cd min-dalle\n",
"\n",
"repo_path=\"https://huggingface.co/kuprel/min-dalle/resolve/main\"\n",
"\n",
"mega_path=\"./pretrained/dalle_bart_mega\"\n",
"vqgan_path=\"./pretrained/vqgan\"\n",
"\n",
"mkdir -p ${vqgan_path}\n",
"mkdir -p ${mega_path}\n",
"\n",
"# -f: fail on HTTP errors instead of saving the error body under a .pt name\n",
"# -L: follow the redirect the download URL responds with\n",
"curl -fL ${repo_path}/detoker.pt --output ${vqgan_path}/detoker.pt\n",
"curl -fL ${repo_path}/vocab.json --output ${mega_path}/vocab.json\n",
"curl -fL ${repo_path}/merges.txt --output ${mega_path}/merges.txt\n",
"curl -fL ${repo_path}/encoder.pt --output ${mega_path}/encoder.pt\n",
"curl -fL ${repo_path}/decoder.pt --output ${mega_path}/decoder.pt"
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-06-30 15:25:24 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
2022-07-01 16:14:31 +00:00
"outputId": "e441aad1-ff42-4187-d7d9-c9f91ef4e2f5"
},
"outputs": [
{
2022-07-01 16:14:31 +00:00
"name": "stdout",
2022-07-01 18:19:35 +00:00
"output_type": "stream",
"text": [
"initializing MinDalleTorch\n",
"reading files from pretrained/dalle_bart_mega\n",
"initializing DalleBartEncoderTorch\n",
"initializing DalleBartDecoderTorch\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"import os\n",
"\n",
"# Switch into the cloned repository before importing, so that `min_dalle`\n",
"# is importable from the checkout and the relative weight path it logs\n",
"# (pretrained/dalle_bart_mega) resolves inside it.\n",
"os.chdir('/content/min-dalle')\n",
"\n",
"from min_dalle.min_dalle_torch import MinDalleTorch\n",
"\n",
"# Build the Mega model once; later cells reuse this `model` object.\n",
"model = MinDalleTorch(\n",
"    is_mega=True,\n",
"    is_reusable=True,\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate an image"
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-01 01:23:15 +00:00
"height": 511
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
2022-07-01 16:14:31 +00:00
"outputId": "118e6888-4d86-4fb9-be39-8b83d0579754"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-07-01 16:14:31 +00:00
"name": "stdout",
2022-07-01 18:19:35 +00:00
"output_type": "stream",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-07-01 01:23:15 +00:00
"['Ġcctv']\n",
"['Ġof']\n",
"['Ġyoda']\n",
"['Ġrob', 'bing']\n",
"['Ġa']\n",
"['Ġliquor']\n",
"['Ġstore']\n",
"text tokens [0, 17685, 111, 24509, 976, 11811, 58, 13142, 1110, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
"data": {
2022-07-01 18:19:35 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAEAAElEQVR4nHT9x+9t25Yeho2ZVk47p1/+nXzOzS9U1atMFkEKlmjZsiEDNtx2y/+BG24YtptuWID66hCGIcmgTVimaBar6oV773s3nnzOL++898p5BjfOK4oW7dVcwNprL+Ab8xvzm98YA330YKYb/r17ZwjR3335crdNfL/38RcPMc3+r//lv7AxpBIAwMCHf/Qnv/gf/if/9Oiwc/Hizf/hf/N/3rfZf/q/+J++fXETJ/sfX38JEAOwP/3pP/hP/5P/+IufPsii6n/3v/3Pvn/9al3c2eb5X/7ZL/7wL/7IM/UsKv+r/+pfPPn07Kc/++jf/Dd/+zd//Zv1ZqWbxhc/++Txw4O/+de//OblLwEkAAD0/+lf/dX/6n/9v1Si+T/97//zb19e/MU/+dO//Ks/+82vfvjq69c//8XPGDM2y/nRwZhRIKT88l999eWvvt42awnlzHt4en7/j//kj7uBncflf/3P/h/KgiefPv0n/9Gfzab2xcXb77/99s3Vxdmjc9nSNOOPnzx1XT+vxeWbW6QaTWe9ToeAunxz/d03P9Zl8x//z//D/+z/+J8PptoPr775/Z/rnHz+08+DzvDB+elk0j08PiEsX76f/8t//m9eXrzK8/jBkwcHh6P3b96/+PbtvUdP+uOZowd+P3jy7LHt8G9/83Vn3Pvn/+W/SMuq0+us1vv9Pg66vS/+4CemrnX63Sop4zTWqPHbr7599/L64OD48Hgabvdv375HgO/fP1VSnp6Piixe3Ny6njM9Gl69vfr6ty9yaADg/tmnn//ks7pGr3+4IaZ+fHLoe+7R8UgU1Xdffc9VighBCNuOoTM1Hftn9w78nvV//7/86y9//buijjSdntw7s5mRJ2lV8M5oqBmubbm6rU+GPYXR7cU2K5KmLfM0/vHH7/IqViBboI/OHh0dHXz2xc/9jt3U9ep68ebNxas3d75r51XK24pQMTvrTUddyzLqNOmMur7n8VIdHs0ePH1YtnKxWAspPdfNquzm9vbb71799pffzdfX8P/nOqCs7/u6QR3PbXjtGw7ztc4gmEzH44Peu5c/CIR++a++/uoq/beP/M/+R3+kav3F1Q/f/rChRNdt05YC2YZGTYIc+uSnjy2NFWVuaZSZRK9UXTfM5ohJhpu+52+7Du1xtREvnr8wqD2eDp6/x4oDQOt0telRZzyZvCzf7OXGHAVwcWf7rD/qWgjVcZHtk+Vq/VQ87JqjwO+nVWE67v37J4/uPzo56H9jv/x79EPPNvuDwNVMoTTmkvOPTj/+7PPA6h8dn3SGI11zCNOms15dVcvrK9uU3dGg07dl3AvzfTB1CYMsj87OpoZuMs8oZFlmJZXqeHo+v5sz0xiOJ7bVGfcPdTuoCx7oE9tBu7v88uI1Fw0v5c9/9plneEkRv/jy1Te/+SHPmnudg0n/4d32FSB/dnqCGEaYS2gG0xFAaxCzrFujY7p7jxLCK8kQmU5mZcVbJrMkG97rDfoWz3def/bZp58IrE5O7xGGm4rs1mmZtx2PxvN0B+rqzSKLc40xyijnKJh2x6fT8cHQ8Y20TAxm/+Iv/tA3ddvRXr16FcWhQhiAjQbTg2H+an0BEC/ultH9uNsbPvjoiCMKBilRHdWlZxn3PjpZJXd5Ulg6G3b8B49PBsOOTUmaJrppdPtOvckQJVmSGgONWhpw3qiKITMvU8s1sVRSkRZ40XKdGFLGhmZnVayAShC26wyGAy/QP/v0k7xMj48Pj++d9cdvy7Yo87LNi2gXJndVuVzYHvM6jhKJSMVoMhF1s71YMdM9mcx6w0Fv0M+KIntSfvHR+i/+ePHjN69Xi+1uHX77w7clLP/dAEh5G+62LYBxt+xQS/ejA5iePR7NZkeGQaeze7toNzs5vY2+W8a/f4Qqs2prVBEAoEqKLE9hLju9blOVoqkdW+
8EZnS5AAyKENey6mafptv9ahHF+6TMJLS2rbfLzYuXv/voo5+4jk+JbDkAgM4QJ61AwGVrGxqXACAMwi1dNDLcRLtGto7HJU6Flg8OzSefzHRN+0f/5K8Cx0ay9QLj77+LdAfm+KBTQCnqdno8HmNrOhsoighlVRnuNklap0WUmAY9Ohic3zvjXJw/u1e/fLXNb5EU/Z77yRcPHIvZLhkcWlWhRlNbc1Ctit7Au/fozNvshES7NLEl3u3y93fbouCr5d3F6+eYyIPDiQIQoJBE23ynbl8ePRhOD0bMIMnXiWmanqubmsF52wqOlCIYIaTpzAw6AdPJbpfyVWk6KvA7Hd9/8erVQt0NB753/9Af2Eq2QdfjEj56eh9hutokl7eG55mn9w5nR/3lavf1989vL+5ANADUsvxhfxhYRt93HcaiydAx/U+ePer37bLJKxGudzeiVaZjdZzuYpNcrpc1xFn9/u5yZujGqD/mBPZRmlf5BrWi42m6xEgArvvDzsefPv7oowembtZJkRaN4dimb+M9402JJNIIJTrOcNnUQmPc1DQkVSNbhaBtm6blOsWW6QghKmgVAIBCgFUjDU0bDDpjNhBCnJ4WCFuL1W3V1p7t5km8Wc1v7653u3XRFJ3Au7tdLOabw8Pxe3rT1vzRR08+t23jSA9cpxi2o+no/NH9P/rjn8RhtJlvn//48eWb9xdvXiznG8Ub2eKYhzkAADQACS/IDizd2c0jz+54llEmsinAMO1e0MnjMAUwAGzdrItaSAEAtKllDXkUZXnB47gSHAhGnutzpZTEUmDT1oxcq5oqirdpGNdFBRJRnQHEWZ5TDWyDEgwtAABwJBk1FKC2IZIiohMAXdOZ33E7HS8KQ9vuHRzOOoHjOMbpvcM/5n9o6tajhw982472uwdP7t8sfnFz/f5gNjAN3XIsLLBuBZPD47ZBbdtWefju+vXz5y9vr5dxmFBMfvKzTx/c/wMv8F3fOjk/fvHyRwDRiCoYOc+ePjI9rSiqP/rzn4X70HU8w9RdzxvOpmVbp1l5u1zFaajEJgrTKCvv7qL1/LpM0tnBaL/eLleL9XLz6t2bu/WPSnvwj/78L8/PZ5touI+jJE0pY5qpSc4pJYCUlEIoITHn0Cou0iwN95u6zbvDnmOa7y7fC4BH9485nJvaUKHW8RxC2L2Hh1XThkXYiEixYjzzHz077wy9SqaIqM16X8VpVVcIY0KQYeiARH/UdTVzPAwmh4M0i+N8dBod8Er2u0PHDBbbcHBzdbtZAsiyzdu2whQo0Rmrk3Rf5Imu45pAyyVgZDvOeDYbj2cEUEYpviIKGsMydI1WrUJIUIaQopgo3rSiqSVlSgkpJBdN29SirY2O1VYgRKtAACgAzEUTFlFW5GkeBkFHgjBsy+v4+3hrevb0YGYxttuu/OfO65co3G33m2i33b1r3283h4KLbRTOd7uKV1Ee2aY9OBx1u/3BcAZMEy3Pi/LTnz1b3C7S3T7eR5v16vry9ofvfri6uguTbfx7DELcVFd3y10SmYaGGTFMXQBhju10Ush54BABLQZWcQUAlDdcN6x1tMwzkSW1ZjDJpWnrSqpWtKLCjol1XauaqmzK7X5dlCVgOhwMLvVh3TQG0zSGNcKqDwHABUJKKF5Wue94nm3+7JOfTQ8mp6eHjx6eahrhrdoup6NeR2f4YDpjjIECKVQYp1WLXb8zGg63d/Ob6zvg6PWTN4Href5ws0kWd+u723WWiS+/+vry8hXAFuOjn/z8i/PzcyHlbrOt8qrOqqbiALBfx0VWSpAdv9/v0KcPslev3jZtG4f1/Ga1T3eb9W6xWPz43cv5ImsEazkv23q13kNRdHtBJ/AoJkVRzW9X33//PQAkSXR0Oj46PsS65lhmXZZtXbdVZZkGUgqUlEpIIaTkSgoFAhEpobmbXwlo9elUfCA1Hce77Y7qu3DvurZl2hWvyqKosjBwTUJ6gasPOr6tM0PD986OsQIukI
YY1XTP8S3TKua5Z1mB6yEsCVaWpvc7naePH4MC2+1iRQ9uV0+Th71VYJru4eGBZrht3WLKLGag3siyNMc1l
2022-06-28 00:58:17 +00:00
"text/plain": [
2022-07-01 16:14:31 +00:00
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7F24258E2110>"
2022-07-01 18:19:35 +00:00
]
2022-06-28 00:58:17 +00:00
},
2022-07-01 18:19:35 +00:00
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
2022-07-01 16:14:31 +00:00
"name": "stdout",
2022-07-01 18:19:35 +00:00
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
2022-07-01 16:14:31 +00:00
"CPU times: user 7.4 s, sys: 31.8 ms, total: 7.43 s\n",
"Wall time: 7.39 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
"%%time\n",
"\n",
"# Colab form fields: the prompt and the random seed passed to the generator.\n",
"text = \"cctv of yoda robbing a liquor store\" #@param {type:\"string\"}\n",
"seed = 0 #@param {type:\"integer\"}\n",
"\n",
"image = model.generate_image(text, seed)\n",
"display(image)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-07-01 18:19:35 +00:00
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
2022-07-01 18:19:35 +00:00
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-07-01 18:19:35 +00:00
}