From 2cac9220b56c57ba878a8e2d3ffa4cb961a75dfc Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Thu, 7 Jul 2022 17:03:47 -0400 Subject: [PATCH] generate_images_stream and generate_images --- README.rst | 102 ----------------------------------------- cog.yaml | 2 +- min_dalle/min_dalle.py | 65 +++++++++++++++++++++++--- replicate_predictor.py | 2 + setup.py | 2 +- 5 files changed, 62 insertions(+), 111 deletions(-) delete mode 100644 README.rst diff --git a/README.rst b/README.rst deleted file mode 100644 index 8c44a0b..0000000 --- a/README.rst +++ /dev/null @@ -1,102 +0,0 @@ -min(DALL·E) -=========== - -|Colab|   |Replicate|   |Discord| - -This is a fast, minimal port of Boris Dayma’s `DALL·E -Mega `__. It has been stripped -down for inference and converted to PyTorch. The only third party -dependencies are numpy, requests, pillow and torch. - -To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a -T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on -Replicate - -The flax model and code for converting it to torch can be found -`here `__. - -Install -------- - -.. code:: bash - - $ pip install min-dalle - -Usage ------ - -Load the model parameters once and reuse the model to generate multiple -images. - -.. code:: python - - from min_dalle import MinDalle - - model = MinDalle( - is_mega=True, - is_reusable=True, - models_root='./pretrained' - ) - -The required models will be downloaded to ``models_root`` if they are -not already there. Once everything has finished initializing, call -``generate_image`` with some text as many times as you want. Use a -positive ``seed`` for reproducible results. Higher values for -``log2_supercondition_factor`` result in better agreement with the text -but a narrower variety of generated images. Every image token is sampled -from the top-:math:`k` most probable tokens. - -.. code:: python - - image = model.generate_image( - text='Nuclear explosion broccoli', - seed=-1, - grid_size=4, - log2_k=6, - log2_supercondition_factor=5, - is_verbose=False - ) - - display(image) - -Interactive -~~~~~~~~~~~ - -If the model is being used interactively (e.g. in a notebook) -``generate_image_stream`` can be used to generate a stream of images as -the model is decoding. The detokenizer adds a slight delay for each -image. Setting ``log2_mid_count`` to 3 results in a total of -``2 ** 3 = 8`` generated images. The only valid values for -``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the -colab. - -.. code:: python - - image_stream = model.generate_image_stream( - text='Dali painting of WALL·E', - seed=-1, - grid_size=3, - log2_mid_count=3, - log2_k=6, - log2_supercondition_factor=3, - is_verbose=False - ) - - for image in image_stream: - display(image) - -Command Line -~~~~~~~~~~~~ - -Use ``image_from_text.py`` to generate images from the command line. - -.. code:: bash - - $ python image_from_text.py --text='artificial intelligence' --no-mega - -.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb -.. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge - :target: https://replicate.com/kuprel/min-dalle -.. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white - :target: https://discord.com/channels/823813159592001537/912729332311556136 diff --git a/cog.yaml b/cog.yaml index 61f36df..7ca0216 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.36" + - "min-dalle==0.3.1" run: - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 152146b..802adcb 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -1,7 +1,9 @@ import os from PIL import Image +from matplotlib.pyplot import grid import numpy from torch import LongTensor +from math import sqrt import torch import json import requests @@ -142,25 +144,29 @@ class MinDalle: if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() - def image_from_tokens( + def images_from_tokens( self, - grid_size: int, image_tokens: LongTensor, is_verbose: bool = False - ) -> Image.Image: + ) -> LongTensor: if not self.is_reusable: del self.decoder if torch.cuda.is_available(): torch.cuda.empty_cache() if not self.is_reusable: self.init_detokenizer() if is_verbose: print("detokenizing image") images = self.detokenizer.forward(image_tokens).to(torch.uint8) if not self.is_reusable: del self.detokenizer + return images + + + def grid_from_images(self, images: LongTensor) -> Image.Image: + grid_size = int(sqrt(images.shape[0])) images = images.reshape([grid_size] * 2 + list(images.shape[1:])) image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2) image = Image.fromarray(image.to('cpu').detach().numpy()) return image - def generate_image_stream( + def generate_images_stream( self, text: str, seed: int, @@ -169,7 +175,7 @@ class MinDalle: log2_k: int = 6, log2_supercondition_factor: int = 3, is_verbose: bool = False - ) -> Iterator[Image.Image]: + ) -> Iterator[LongTensor]: assert(log2_mid_count in range(5)) if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) @@ -219,8 +225,53 @@ class MinDalle: with torch.cuda.amp.autocast(dtype=torch.float32): if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: tokens = image_tokens[:, 1:] - image = self.image_from_tokens(grid_size, tokens, is_verbose) - yield image + images = self.images_from_tokens(tokens, is_verbose) + yield images + + + def generate_image_stream( + self, + text: str, + seed: int, + grid_size: int, + log2_mid_count: int, + log2_k: int = 6, + log2_supercondition_factor: int = 3, + is_verbose: bool = False + ) -> Iterator[Image.Image]: + images_stream = self.generate_images_stream( + text, + seed, + grid_size, + log2_mid_count, + log2_k, + log2_supercondition_factor, + is_verbose + ) + for images in images_stream: + yield self.grid_from_images(images) + + + def generate_images( + self, + text: str, + seed: int = -1, + grid_size: int = 1, + log2_k: int = 6, + log2_supercondition_factor: int = 3, + is_verbose: bool = False + ) -> LongTensor: + log2_mid_count = 0 + images_stream = self.generate_images_stream( + text, + seed, + grid_size, + log2_mid_count, + log2_k, + log2_supercondition_factor, + is_verbose + ) + return next(images_stream) def generate_image( diff --git a/replicate_predictor.py b/replicate_predictor.py index bc2efa3..5177c5d 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -1,5 +1,6 @@ from min_dalle import MinDalle import tempfile +import torch from typing import Iterator from cog import BasePredictor, Path, Input @@ -53,5 +54,6 @@ class ReplicatePredictor(BasePredictor): except: print("An error occured, deleting model") del self.model + torch.cuda.empty_cache() self.setup() raise Exception("There was an error, please try again") \ No newline at end of file diff --git a/setup.py b/setup.py index ae817f2..35b1b24 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.36', + version='0.3.1', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',