min-dalle-test/min_dalle.ipynb

182 lines
173 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Download Code and Models"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-07-01 01:23:15 +00:00
"execution_count": null,
2022-06-28 00:58:17 +00:00
"metadata": {
"cellView": "code",
2022-07-01 01:23:15 +00:00
"id": "ix_xt4X1_6F4"
2022-06-28 00:58:17 +00:00
},
2022-07-01 01:23:15 +00:00
"outputs": [],
2022-06-28 00:58:17 +00:00
"source": [
2022-06-29 20:32:51 +00:00
"%%shell\n",
"# Abort on the first failing command so a broken download cannot go unnoticed.\n",
"set -e\n",
"\n",
"REPO=/content/min-dalle\n",
"HF_URL=https://huggingface.co/kuprel/min-dalle/resolve/main\n",
"BART_DIR=$REPO/pretrained/dalle_bart_mega\n",
"VQGAN_DIR=$REPO/pretrained/vqgan\n",
"\n",
"# Clone to an explicit path: the Load Model cell chdir's the kernel into the repo,\n",
"# so a bare `git clone` on re-run would create a nested copy. Skip if already cloned.\n",
"if [ ! -d \"$REPO\" ]; then\n",
"  git clone https://github.com/kuprel/min-dalle \"$REPO\"\n",
"fi\n",
"\n",
"mkdir -p \"$BART_DIR\" \"$VQGAN_DIR\"\n",
"\n",
"# -f/--fail: treat HTTP errors as failures instead of saving the error page as a model file.\n",
"for f in vocab.json merges.txt encoder.pt decoder.pt; do\n",
"  curl -fL \"$HF_URL/$f\" --output \"$BART_DIR/$f\"\n",
"done\n",
"curl -fL \"$HF_URL/detoker.pt\" --output \"$VQGAN_DIR/detoker.pt\"\n"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "markdown",
2022-06-30 15:25:24 +00:00
"metadata": {
"id": "kViq2dMbGDKt"
},
"source": [
"### Load Model"
2022-06-30 15:25:24 +00:00
]
},
{
"cell_type": "code",
2022-06-30 15:25:24 +00:00
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
2022-06-30 15:25:24 +00:00
},
"id": "8W-L2ICFGFup",
"outputId": "aecc5295-ba0f-423f-983e-b726b5f56cbf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initializing MinDalleTorch\n",
"reading files from pretrained/dalle_bart_mega\n",
"initializing DalleBartEncoderTorch\n",
"initializing DalleBartDecoderTorch\n",
"initializing VQGanDetokenizer\n"
]
}
2022-06-30 15:25:24 +00:00
],
"source": [
"import os\n",
"\n",
"# Work from the cloned repo root: the model reads its weights from relative paths\n",
"# (pretrained/...), and `min_dalle` is presumably importable only from this directory.\n",
"os.chdir('/content/min-dalle')\n",
"\n",
"from min_dalle.min_dalle_torch import MinDalleTorch\n",
"\n",
"model = MinDalleTorch(\n",
"    is_mega=True,\n",
"    is_reusable=True,\n",
")"
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate an Image"
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
"execution_count": 4,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
2022-07-01 01:23:15 +00:00
"height": 511
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
"outputId": "7773a2c9-70a4-40a9-8fbc-80529ef1f68c"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-07-01 01:23:15 +00:00
"['Ġcctv']\n",
"['Ġof']\n",
"['Ġyoda']\n",
"['Ġrob', 'bing']\n",
"['Ġa']\n",
"['Ġliquor']\n",
"['Ġstore']\n",
"text tokens [0, 17685, 111, 24509, 976, 11811, 58, 13142, 1110, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAEAAElEQVR4nGz995dlWZYWCO4jrn5aP9Naurm5VqG1ykitKiurSAqoaRoaWAuaYfUCVncPMDSLQVUXUPQ0FEUpqlJEZlZmRmTocI9wLczNzd20ePbsaS2uPufMD55VXcPM+QPOWvfevc/59re/b1/0xJlnF88+G5CU7e2DcCS6dHxBInKlWJ+YGE0mU2999+233/5ZqfsZwABAC2Dw6YvPvPi5p1959dzV927/i//t38cHUlTW7t2574ENgH/ll77zV37ll9/+7vuX33/nUa6UGky+/Lknn3z+bDxqWL714XuflPKFdte7cev+yNDwS68+s3p3N186ZAL3TPYP/uGvC4/94e/8aG9vP5IacnvupSefxho9eWxoZ31vcn7g9//TW7/8na+u3Nr8jd/6wwsXl3RFOnvpzIvPne/1+NbWYand+OjjT29d+QzgCCAHAG+8/j9OjI3cXXn49//nvzEzMf7RO59atvPmL5396Eef/eO/97/zoD8+NxKLJeKRyNzSYq/T+aP//P31ze1ALDw2PhpLJWWhl2v5bquJBJw6NQ2e+N3v/3AgOpIamQzosV//6788PjJYKeeahabpN37+w/fzxUJmIgogJMqZZSuaJnzm+wwRgikhGEsyBUAcwPM45ghLHAAyqQgIPDg80Kx2C+WOx0Sr2bd9XwvpS0szfr9vWZ7X9+q1ysr9hxcunpuZHhESOTo4vPjM6Xt372892F46sfjJR59xn/2FX/8lp97PHVZSA4MYs1BE+8EfvT02N2rIgX/zf/wGgPbVr/9qIhL7+MMPI8ngqdPznXodIendn71fatRTiUQwGPnat75otnq25Z28eGJ6anx3Y/coXzId89nnz69cX/uTH/58YGzk08t3C7XWhUvLqUgknoxyz5tfGKwXKpoWfPLFc0erxR/+8O3AkHbtw8vnnj4bTWZ/97+8bbrKG29e/Jt/55dmJybX7t9BKvzG//M/fPDT97tQD6Lom19+4dyZc7/3W2/d3P8MAJ49+ZJpdR1XHD81t3B85MHNh7/34+/B/7/1xde/HA0nxqZnz1xYTg1lB7LJZrmkSApBnu/xnZ39VCJc3W+99YOPHm3v/t1/9NeGMonPPvh04+F2MBaSFP3kxelm7oD+0q/9+mBsoNNpDQ4Orj/cuPbeDV84o6Pjxd3c+z975w9/8u8AvHPnvxVQAmvrj1RZ+eLXXpwdG2/l6lc+/qTYOMpMj46OjPnIzu2W+qb35MvnR+bjdVYMLaZePb+IkD8ykhG+u7mWf7T5cG1je2x0iiA2fXzp/IVzgVh0aFpwDQnAr7z+9MVLJzuFLsMO1pWXv/j8YCzz3e/++ManK8f+5f84MplJZZInLxzXQsG5E3O/8hdfY4jevHFnmTmLx8c++fTOD99+q2uJRCbz1V/55uHR3sbOQ+55v/Y3vhmXjde+eD4YlDUVT84Prty9Z8ghOahMnh3bPzxMZdKartXatfx+bmRsfP7kMZv4jusRGQgGSRapgaQPLgIezaTMnnnx/Pmxmfl2Xyiq5DC/VikVi9XcztbSqaVINmExCzGqqNRx+lhVgBLf92zflUDTJEXXDMd3AQTCmBLCkQ/CVxTiuigbT4GLJBxArE9lsnxiYmf/YGdjL6TJn//Sy51W59//8/8wPDfpudbM4vRrrz778NEa881qsXGUL80eWzzzxJOReKpQLAQj6YmpWCSR+0+//QelUuXk6UXbc2dnFmo1G8AAyK8/3PyVX/nSo/WQaZqtVmdyemrz0brjsenpiX7fHJsa6XQbvW6bYGkokxhOJvIHe61ePRKNTI4M3bi8Mjg5EY5Gw9mBQq3eNP1QVEgBIgtJU5TBkaRhyIkQJoPpYqF5/dM/+Vt/629PDA2ureekaPjszGw6m2iWTSvVmxibuPPglun31IDS7ZHs5IBt83qtduzEeDQVLO
Squb2DTrd14dLFuem5ADKGB6befPL13e1du2vu9HN/PgHe+un3AdDJqbOl/b2p6cloNPTEMxd1XU4lst1eDxOhG4FWy5w6MTg0n9K4arWsxEDaQ1zVpW7b9Ls8EozTRGYYHNRotPZ39lbv3bm2cv3cwjlFptevXD6q5ycGTlebreefujQ4kMU/ce7evadoRNLx9Zu3d/KbAjkSQVPjI8jr5/Y3R0bSmYjR6dTKtYNWuXL2xLFy6fAwv7W6fqPd6jiejQTGCF743NMXLX773r29zUfnLp0en8p0uy231y9Xm7VGMZkNywE1EVEcs1Gv79NA24gotUqzb+kbj9bee/+d8bHJ+eOznHm7IbG1fud3/wvdzR1Kwo2oRjIYeOG1J5i9dO92JpaKTQ9nIxHDtWLFcnlrb3dvfzOaDmhBWVbkYFqhJcbAyueL5aNiPBkLhQPTk8MHu3t7e3s79UY5UFEDBgJOsRieyARDQbPfiSbDx44trjxY7XVq+wfr672eb5mpgYgRMGRCHeQxk+khCXMJwOdMMCEQQowzn3Pf94AJj3kcCUXWdF1iNseAPNN1Qr7pOwKTcCpUa7ct31UIPjjcIJKF+UuU8PPPLQUCof1dShVm9lvg81LpKJ1FwWBgc2vb9lgmkWGW+P3/+PsMeEjXike5RqvlstmJuUlBJCZMBWmOUIcG4xyLUFSz/U48Fe91exyEQBCJhVuNZshQJYQa9aqq6+Gw0em3G61qq1lJpxO1dpeqRFbJ+sYjDlZ2aHAwFYpEjFa9zn0fE3Mqk04loswXFnGkAS1cTMmybESiPl+bHB6YmogfHux43lkPcczxwW7ZZqbJuiFVHUzH3H7TULVXvvTC/oOtzfWjDz58p8Yq5aOcbS5mBpIDNClpC8+9/OTh1v67P/n5VmNfBWj/X1kg7m7fuLt9AwCSEDx+8vjrrz2/dOF0QFXDqQRCLqJien7EtftC7vUY4siNRkOEILPX67V7iUyIlhvm3s2rpWrzaO/w2sptgIbr+IqsHhRy2YGBF195wfL8k6eXgsFA6T/k642VzbWN0WzW6Vs90wTBUslgIh0ol1RFV5Jxw7Ib1SP87ofvR0MxNaQMaYO22a82y41mY+nEUr9vywadmholWN073LJtc3ZxjrrSjbvXL3/2cc/sZJPZufnFn/z03c8+/XR4aPirX//C3t6WoqFao0Rl/Pa7b3vexr17keNn/ncs5O2d7Wa7/kvf+kqpWu3ZDUnyI1GciEvDmWOnz09IEgkaVHAnEg2+++6HxUKtWCt85zu/1O5bWEGGoXrcrhWLtufcuXt1ZnoeE6RFDMbtSqXkQ6XaCAKQkJE8tjw3PDikENTsNCqlqkowBS+/f3BiebHfaiKChoaGAHlURsxnQjCMCJUlDIz5jFDMBUGCCeF4DJjPGfOJRDAwjIET4TOHAW51mlwQLajJquQUzDapFUqHzfbOUYGWKqWx8dHnX3rBcdyHW2uP1lYPtjfXH+74nA8OTu1t7V++/gGAMTY6Bczbz98BoCeXnh6aGl9OJi5dvGjoocN83gZzenFkePjSm196qddphaNBIL4q02a95jGXSEIwJoSHkKcq1HFs7vuW3csmU5ZtV+q1IbtfKJYFgVK5UK1VJCP0yuvPBYOy3e/vbh112x1umzrFik4VzehZzuBAQkKLrtnJF4uWZz/xxJmjwgGANzyZlah8eLi/vr7Rq9t9qzs2MnpqeUmWicD+qXNnpqZGyI8+unZNhh60nc69e9e3Nh8KJKamp2fnFxCI7HB89uRkJpW5feXGo8MtE8D7cxdCFbrv3/3Udvp//NbPRiYHn33h6cnJmUQiqshBm5kSUfrdDmMoFE0y36VSywGuB8L0f/kf/nbfaiuqIVEqg+KCnkolTLvxpS++MTo/EYnE7q+s3Lp+KxyN12stAKjVmq1eO51KDqRT+f3DQEDxnPbSsbnr125evXbtr/zV77iuEzQiY6PjBqWBeDqVjE3NT1pde2
Zh8tHaxrVr1w8Lh5cuPJ0Ih1fvPbpx+XpQCb/39ntr9x7Uy82/9Gu/Eg3Gbnz68Y1PnVc+983nnjl35vSyY
2022-06-28 00:58:17 +00:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7F7DCA30DC10>"
]
2022-06-28 00:58:17 +00:00
},
"metadata": {},
"output_type": "display_data"
2022-06-29 17:55:23 +00:00
},
{
"name": "stdout",
"output_type": "stream",
2022-06-29 17:55:23 +00:00
"text": [
"CPU times: user 7.38 s, sys: 11.8 ms, total: 7.39 s\n",
"Wall time: 7.34 s\n"
2022-06-29 17:55:23 +00:00
]
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
2022-06-29 17:55:23 +00:00
"%%time\n",
"\n",
"# Colab form fields: the prompt and seed can be edited from this cell's form view.\n",
"text = \"cctv of yoda robbing a liquor store\" #@param {type:\"string\"}\n",
"seed = 2 #@param {type:\"integer\"}\n",
"\n",
"# generate_image returns a PIL image, which display() renders inline.\n",
"display(\n",
"    model.generate_image(text, seed)\n",
")"
2022-06-28 00:58:17 +00:00
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
"include_colab_link": true,
2022-06-30 15:25:24 +00:00
"name": "min-dalle",
"provenance": []
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}