From cfb9f60b6ec276de88385361b493711c1c7fa7b2 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Sun, 10 Jul 2022 13:23:42 -0400 Subject: [PATCH] dtype dropdown in colab --- README.md | 2 +- min_dalle.ipynb | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0b973c1..f0735ed 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@   [![Discord](https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white)](https://discord.com/channels/823813159592001537/912729332311556136) -This is a fast, minimal port of Boris Dayma's [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch. +This is a fast, minimal port of [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch. To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a T4 in Colab diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 877d6e1..f8236f9 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -102,7 +102,7 @@ }, "source": [ "### Load Model\n", - "Float32 is faster but uses more GPU memory. Change the `grid_size` to 3 or less if using float32." + "`float32` is faster than `float16` but uses more GPU memory. Change the `grid_size` to 3 or less if using `float32`." ] }, { @@ -128,13 +128,14 @@ } ], "source": [ + "dtype = \"float32\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n", "from IPython.display import display, update_display\n", "from math import log2\n", "import torch\n", "from min_dalle import MinDalle\n", "\n", "model = MinDalle(\n", - " dtype=torch.float16,\n", + " dtype=getattr(torch, dtype),\n", " is_mega=True, \n", " is_reusable=True\n", ")"