min-dalle-test/min_dalle.ipynb

168 lines
170 KiB
Plaintext
Raw Normal View History

2022-06-28 00:58:17 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
2022-06-29 02:12:36 +00:00
"id": "view-in-github",
"colab_type": "text"
2022-06-28 00:58:17 +00:00
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3WL-G_f2_ld8"
},
"source": [
"# min(DALL·E)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zl_ZFisFApeh"
},
"source": [
"### Download models and install dependencies"
2022-06-28 00:58:17 +00:00
]
},
{
"cell_type": "code",
2022-06-28 16:16:44 +00:00
"execution_count": null,
2022-06-28 00:58:17 +00:00
"metadata": {
2022-06-28 16:16:44 +00:00
"id": "ix_xt4X1_6F4",
2022-06-29 02:12:36 +00:00
"cellView": "code"
2022-06-28 00:58:17 +00:00
},
2022-06-28 16:16:44 +00:00
"outputs": [],
2022-06-28 00:58:17 +00:00
"source": [
"# Clone into an explicit path (later cells assume /content/min-dalle) and skip\n",
"# the clone when the repo already exists, so re-running this cell is safe.\n",
"! test -d /content/min-dalle || git clone https://github.com/kuprel/min-dalle /content/min-dalle\n",
"! mkdir -p /content/min-dalle/pretrained/vqgan/\n",
"# -f: fail on HTTP errors instead of saving the server's error page as the model file.\n",
"! curl -fL https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack --output /content/min-dalle/pretrained/vqgan/flax_model.msgpack\n",
"# %pip (not !pip) installs into the environment of the running kernel.\n",
"%pip install torch flax==0.4.2 wandb\n",
"! wandb login --anonymously\n",
"# Both checkpoints are fetched; the mega one is only used when `mega` is enabled when loading the model.\n",
"! wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0\n",
"! wandb artifact get --root=/content/min-dalle/pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14\n"
]
},
{
"cell_type": "markdown",
"source": [
"### Load Model"
],
"metadata": {
"id": "kViq2dMbGDKt"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"# Colab form parameters:\n",
"#   mega      -- pass True to load the larger checkpoint.\n",
"#   use_torch -- True selects the PyTorch backend, False the Flax backend.\n",
"# The backend flag is deliberately not called `torch`, which would shadow the\n",
"# `torch` package for any later cell that uses it.\n",
"mega = False #@param {type:\"boolean\"}\n",
"use_torch = True #@param {type:\"boolean\"}\n",
"\n",
"# The min_dalle package is imported from the cloned repo and reads weights from\n",
"# relative paths (e.g. pretrained/dalle_bart_mini), so work from the repo root.\n",
"os.chdir('/content/min-dalle')\n",
"\n",
"# Import only the selected backend so the unused framework is never loaded.\n",
"if use_torch:\n",
"    from min_dalle.min_dalle_torch import MinDalleTorch\n",
"    model = MinDalleTorch(mega)\n",
"else:\n",
"    from min_dalle.min_dalle_flax import MinDalleFlax\n",
"    model = MinDalleFlax(mega)\n"
],
"metadata": {
"id": "8W-L2ICFGFup",
"outputId": "4ec5f57a-dd63-4dea-894a-cc17f5758ac7",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"reading files from pretrained/dalle_bart_mini\n",
"initializing MinDalleTorch\n",
"loading encoder\n",
"loading decoder\n"
]
}
]
},
2022-06-28 00:58:17 +00:00
{
"cell_type": "markdown",
2022-06-28 15:01:31 +00:00
"metadata": {
2022-06-28 15:05:59 +00:00
"id": "c52TV1GbBNgS"
},
"source": [
"### Generate an Image"
]
2022-06-28 00:58:17 +00:00
},
{
"cell_type": "code",
"execution_count": 15,
2022-06-28 00:58:17 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 392
2022-06-28 15:05:59 +00:00
},
"id": "nQ0UG05dA4p2",
"outputId": "dde449ae-1301-4316-aaa7-4eb249fd88fe"
2022-06-28 00:58:17 +00:00
},
"outputs": [
{
2022-06-28 16:16:44 +00:00
"output_type": "stream",
2022-06-29 02:12:36 +00:00
"name": "stdout",
2022-06-28 00:58:17 +00:00
"text": [
"tokenizing text\n",
2022-06-29 02:12:36 +00:00
"['Ġartificial']\n",
"['Ġintelligence']\n",
"text tokens [0, 6316, 7815, 2]\n",
2022-06-28 00:58:17 +00:00
"encoding text tokens\n",
"sampling image tokens\n",
"detokenizing image\n"
]
},
{
2022-06-29 02:12:36 +00:00
"output_type": "display_data",
2022-06-28 00:58:17 +00:00
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=256x256 at 0x7FD8EE306550>"
2022-06-29 02:12:36 +00:00
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAEAAElEQVR4nDz9Z9y22VUXDP/X2nsf7WxXL3e/p7dkkkmbJKQBCYT6IAEVEURQxPdR8VF8/Vmx/NRHEXn1QRRUpAcIQVp6zGQyKTOZ3u977t6ufp39aHvvtd4P1/Ccn45Px/llH2uv9W+LegADhiECxus/IoCgDgQDhQIEVQIpRYAjhBjGKcQaUhEiJhgQkyICpCrKRGJZoTEQk8kMGRVSC2ajAtVoVAVGXUFK0MAkIAN2xABIAJAlNSQtWxEkAZl1mYCFLFhE1DIhNi5xIhB2bJ0gCqwxJoZATD4Ycql1rvWeiUFGNGaOgqAVZHkn1mpNaMFBDUhDE5M8n9dSZEZAUdJOx6bSTod7ddto0k0SVlGB73QKa3hyOKqmcvzeOzqIF199MWoNLlxR+LIyeZobXlvfKCfVwfbFTj9Lk6Kl9Oy997741cfb0bhzdnM+b0BJb1BMD6oPfv/379+4+cwffqp/vDepA9IV1899WSf9laLoBmrzBbu3U2FUIQSECt0B5iX6DpQPllfj5Mrs8HreX1nbOHH1xYv5cl5NI5o5shwgQJAvrB3f8MFa0GCg072ruzduZKkjlTLYLKV6Vi6ur7bToXitYnf99B2pTA73DohZ0SAxeZbu7+wZVo02Ec1trPy8DWwT6zIzr+bOSmZMPZ6fOHV6dLg3Gc1iCFknq+Zl6pRMx6MYLK+k2dLO1Wvix0knaauxyzOjs7rp4PjDaD1295EbyAjjvez4mp+PVYKoQxNBCeDhp5gPESsooDNnCh8PgMiAAQMKqIIUymCCCsgwgwwIBJIQmaxSBKnLLCcpGIgAETzggUBQA7FGKI1iRZ2Ki5JIcFHSxidBsyhpGzrOrsAXpN1WsyCpj9aL9ZJ4SQNnLYpAqVcXTRaRQFNoEiVtYxIpaWFb4kCmigncApCCkpYRWIPGCImkXiEmihUxQW2DxEcbxVTEU0FNJMxqDBNHIBCcQQ6kzKmFIRVFAItj48hYK2Q9cZ0m4hxgWk5JC+4siEumXneJJsZVRScmruwVPu2pmoCO65zYDLQbwo64gLRRNAizXm6y1OaFrdoJVhfM8sb6sZ6TChxgKmLfWyyKQWf59OaJd65rvzGpzfsc0RIl+7PQ+gxL6yH6Ts9kPd9KC5i/9KNvf+Dtmxt3b9ah7Qx6cMHP55Tz9/+Vd93zUH8wwLyZoku463Ys9XFi9a3f+dbunT2sL+Z33faXf/QtdrpTdDjppJo4dBOTBRQlL0m6IMV6ki3n+aptnbEbvTKlWhqbpc7VwCSahpJZ4DpZSMR56VBt2nSQ9ZcXFtdW1Iw5GTGaXtcRR5fOQ5yRiafOrHXSxDljCjUUDCE1vcKsJ6FYWdwgEx0okiAh0TZxgKhoNL2e7QyUyiyvbdp43S8y+MmeravllfzDf/fDK9/wBnzHBz/yK7+w/k3fhY37a2ljqKRtiOp8faVYWYdh+BYUoQxVwHoZAxFgBUBEzAEQkII8xIMDEEFeOSgFMCWpN6om9Sapo7EipCARYUNGiSAKJhh6/RFQEFFUMFMUpIYNEBGN8WkwZFwDnxgFogqUxMIQLIyoKjExOzATRCFilQREBji6L6yIJDYkaCRhDzFsiVjl6CpiNhZGVJyIgbMKQ2oMsYKYiJTBxDBRg1IQDaqeKcYQxTERA5G0UUqJGBojxCppFOHIobVGXOVgQ/BNCq6bOVrX+pbThNSENjpKYut5i6it1Qr7hsalzQtGnE+n0310lwoTZx1q9XBb1lrkYoKLodVYGRNG450422W/s97fmPIOiEzm5r7dOL704Xf9uS//9u+1KpZTI14d52uDxcHg8vbh4tLg+AN3bm2Ne9LOm9lCmqQHe/Ptm4PBYH5z7BYHa/ff//f/4V9/8hNfe/r3v3r/yWOzxc7OKzcvPr67vFrcGo+qciycdzpsgk+Ye1mHDIu1FEI/QT
veGWBelXUaJM0Ku3i81obZzauZRCGFTituQ88yo6LJXqOTrk0r9UwU596H2kUbORpr2qrqJN2ZVoUjFq6bNhsM3vbGd3/3d7335/7dz7TTQ82SJC4IWm2F0sCi0XOczimb+zhJM43RMceggUgjC/ly9yuf70zjNCzOVf7Bv/rR//FPF5/75C/ARMTIahAbjYxEbNGJZathBkRoCyXAAI4peo1GicACJRgDImJRUbBhGJgIhSAhFyQkhiOilaDKSoBEJVYBiDVKNGxEVTWSGAVACKIMVmgkigSw1iQOZAgCL8RkmEkgqiQgkMAoQY3ERom9eiggxGAASoggJSapApRYHUgRRRhMIGMpqjZQqALEFGFJiVqNomQtKZitKKIh8qpK1KhCtTHq4a0KAcxsIJ5NasBMxIgxBiEHQxJDbJs8zWeNs0XmCMY5DS0B0YtvmjTJythSN7HST+ZNstiZTw4VUtZ10e+3sa3Fpr0VHw8nu2LPnF3aPNnMq+l8PwYZj0cQU/qJO3XP0uoZ7tyEs5m1g4XNt7/lgQTN2dvXdw9GTaPzsllezHOTupa3ru3V1Wx+zY/3hxvrS0Vq4cdf/+Sj5/bKM2//BhXDiOuri7iKe+3SIze2Nu68TT3a8bbt31ebnnOtdc63Va+To47NJIDK6JwogWiwPLgx35vEyqfUmiwWvTBYK1BrW8FKVcWokZJkYXFpOtzO+tluM9rYXA6zln1jqFatTZTICZNYy5omh/Why7rj6dA3ImjK1n/+U3/w0quPNYfjEMqkv+BSrqe1lSDqvYdhiWW+c2EH7QgdyvI6Vg0bkycmeD/Z2bn48U8ofGO6j/rJ7CM/cuKDH3ru0T/AVKFDiaWMKkGJpgyUIk9MMojVDLUFGkCAqBoYGo96eEARFJZJSAMLiZKCiDmKgJ1lChACWVEwKEKtBcAQqBwVfRApq4kGhiiQkqRMFKMQYKN6QnDQRshAyVoiUgVTJBUJiEZVCMREzBRVieLRjaDKlgEhAlShKmDRVpSFmR0BBDCDQlQlEBkQEZG1NhG1kUSZYwSkFVLRiEiiUFZVISUFMaIGoypiCLFFGliIWQMCiWoTKFEE8aE1ScLG15OGVKNxCrVpnqX53Ji1E6emlRSrC6n1NjFqHLErBgst2SzPO51BQ5Dp+GD7Wq+7nhy/b/X4vZvHj6+sJqfPnF1zpm87fcMS3COPfOn67qVQVlnCs/1hvX3j5Zefun7pQhMIlObGsefpaDI5OAx18E0Zmjm144PtmSGTplRFbdHrnrj96qvPt7PZ819+8lN20lx69cbNGwezcRU0xvqZ53bPvu3dT3/2k6nVJkB7ncJi78ZOmg+C9wJO2BpHIYmj6bAEJPKm5f1ynBHJaMQEuMQ3YXV9ZanfO7x23edxtH3gJ1MNWo5mrE0ncVkv57ZiTSlP1++9/Sf/zk/92Lf97bX19Sb40cGwWxTohdFwP0kKQ0Vb1r2815ZzY9mKiYaseld03/Pe797bffqZxx8nMsyWSYjEmjRqbKPWqvB1PP/cl375Zx/4yN8q3vH28kuPwFTalO3RMBoMwhymjSbD0hJGMwfrfQVfE0AwQBSAXj9DxjAfnQoLEJGqKgOixEyqJGKJoArDUAAQIbAy1HEwxoQWBpBW4dQasqJ6VG6ViAwBABMpQSSSqBKBo0Y1ICLAxpgwDDgqKXGEZgww4tFQrQwFgUhEYECaRGVFUM9KDHYRAcyseP0vWMk20EisBEMqwgmzE4nEzgvYGRUAztkYQwsyTVOzTevpBGKMFZfnzLYlLfKOWj1z1/33v+u2xz/5lWTzBC8WmxtLnZS57diYlPPZZHhha69+w7ve19K09aVHQkQpotStqeqcqKyqV5/f6/R7dmn19vseOPeHfzJ54vlXqX3aZofD7SzPNjaWR+PJwd5+rBsNPltdK8fteH9+sNfW3qgKQ0ENMyVJKm2ZK8ogtTQ+tpSINUUU1G2ESQ/2RuLbvLNYpO6F565yVb
adjpe6kxZF0b98cc+tlMYW3cWuonvnN37j+z7whj/8+d/Y3bpZGJstde99+5v+6o9/8HDf377iKuDi1ek9p
2022-06-28 00:58:17 +00:00
},
2022-06-29 02:12:36 +00:00
"metadata": {}
2022-06-28 00:58:17 +00:00
}
2022-06-28 15:05:59 +00:00
],
"source": [
"# Form inputs: the prompt text and the seed handed to generate_image.\n",
"text = \"artificial intelligence\" #@param {type:\"string\"}\n",
"seed = 0 #@param {type:\"integer\"}\n",
"\n",
"# Tokenize and encode the prompt, sample image tokens, and decode them into\n",
"# a PIL image using the model loaded in the previous cell.\n",
"image = model.generate_image(text, seed)\n",
"\n",
"# Render the result inline as rich display output.\n",
"display(image)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
2022-06-28 15:05:59 +00:00
"collapsed_sections": [
"Zl_ZFisFApeh"
],
2022-06-28 16:16:44 +00:00
"name": "min-dalle",
2022-06-29 02:12:36 +00:00
"provenance": [],
"authorship_tag": "ABX9TyPYiD/1K6WVDkiSthQa8puM",
2022-06-29 02:12:36 +00:00
"include_colab_link": true
2022-06-28 00:58:17 +00:00
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
2022-06-29 02:12:36 +00:00
}