From 5f4815775b5a7f78deeff23fb270efab7d50a934 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Mon, 4 Jul 2022 20:02:33 -0400 Subject: [PATCH] generate_image_stream --- cog.yaml | 2 +- replicate/predict.py => cogrun.py | 21 +++++++---------- image_from_text.py | 26 +++++---------------- min_dalle/min_dalle.py | 39 +++++++++++++------------------ 4 files changed, 32 insertions(+), 56 deletions(-) rename replicate/predict.py => cogrun.py (78%) diff --git a/cog.yaml b/cog.yaml index 2d5c88f..af25f52 100644 --- a/cog.yaml +++ b/cog.yaml @@ -10,4 +10,4 @@ build: run: - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html -predict: "replicate/predict.py:Predictor" +predict: "cogrun.py:Predictor" diff --git a/replicate/predict.py b/cogrun.py similarity index 78% rename from replicate/predict.py rename to cogrun.py index 684274a..2da19ff 100644 --- a/replicate/predict.py +++ b/cogrun.py @@ -1,8 +1,8 @@ +from min_dalle import MinDalle import tempfile +from typing import Iterator from cog import BasePredictor, Path, Input -from min_dalle import MinDalle -from PIL import Image class Predictor(BasePredictor): def setup(self): @@ -30,19 +30,16 @@ class Predictor(BasePredictor): le=4, default=3 ), - ) -> Path: - - def handle_intermediate_image(i: int, image: Image.Image): - if i + 1 == 16: return - out_path = Path(tempfile.mkdtemp()) / 'output.jpg' - image.save(str(out_path)) - - image = self.model.generate_image( + ) -> Iterator[Path]: + image_stream = self.model.generate_image_stream( text, seed, grid_size=grid_size, log2_mid_count=log2_intermediate_image_count, - handle_intermediate_image=handle_intermediate_image + is_verbose=True ) - return handle_intermediate_image(-1, image) \ No newline at end of file + for image in image_stream: + out_path = Path(tempfile.mkdtemp()) / 'output.jpg' + image.save(str(out_path)) + yield out_path \ No newline at end of file diff --git a/image_from_text.py b/image_from_text.py index 666aa7b..82d908d 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -1,6 +1,7 @@ import argparse import os from PIL import Image +from matplotlib.pyplot import grid from min_dalle import MinDalle @@ -13,7 +14,6 @@ parser.add_argument('--seed', type=int, default=-1) parser.add_argument('--grid-size', type=int, default=1) parser.add_argument('--image-path', type=str, default='generated') parser.add_argument('--models-root', type=str, default='pretrained') -parser.add_argument('--row-count', type=int, default=16) # for debugging def ascii_from_image(image: Image.Image, size: int = 128) -> str: @@ -40,8 +40,7 @@ def generate_image( seed: int, grid_size: int, image_path: str, - models_root: str, - row_count: int + models_root: str ): model = MinDalle( is_mega=is_mega, @@ -50,21 +49,9 @@ def generate_image( is_verbose=True ) - if row_count < 16: - token_count = 16 * row_count - image_tokens = model.generate_image_tokens( - text, - seed, - grid_size ** 2, - row_count, - is_verbose=True - ) - image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy() - print('image tokens', image_tokens) - else: - image = model.generate_image(text, seed, grid_size, is_verbose=True) - save_image(image, image_path) - print(ascii_from_image(image, size=128)) + image = model.generate_image(text, seed, grid_size, is_verbose=True) + save_image(image, image_path) + print(ascii_from_image(image, size=128)) if __name__ == '__main__': @@ -76,6 +63,5 @@ if __name__ == '__main__': seed=args.seed, grid_size=args.grid_size, image_path=args.image_path, - models_root=args.models_root, - row_count=args.row_count + models_root=args.models_root ) \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index b02a6f6..46bd1b8 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -5,7 +5,7 @@ from torch import LongTensor, FloatTensor import torch import json import requests -from typing import Callable, Tuple +from typing import Callable, Tuple, Iterator torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) @@ -159,16 +159,14 @@ class MinDalle: return image - def generate_image_tokens( + def generate_image_stream( self, text: str, seed: int, grid_size: int, - row_count: int, log2_mid_count: int = 0, - handle_intermediate_image: Callable[[int, Image.Image], None] = None, is_verbose: bool = False - ) -> LongTensor: + ) -> Iterator[Image.Image]: if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) if is_verbose: print("text tokens", tokens) @@ -196,6 +194,7 @@ class MinDalle: ) ) + row_count = 16 for row_index in range(row_count): if is_verbose: print('sampling row {} of {}'.format(row_index + 1, row_count)) @@ -206,13 +205,10 @@ class MinDalle: attention_state, image_tokens ) - if handle_intermediate_image is not None and log2_mid_count > 0: - if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: - tokens = image_tokens[:, 1:] - image = self.image_from_tokens(grid_size, tokens, is_verbose) - handle_intermediate_image(row_index, image) - - return image_tokens[:, 1:] + if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: + tokens = image_tokens[:, 1:] + image = self.image_from_tokens(grid_size, tokens, is_verbose) + yield image def generate_image( @@ -220,17 +216,14 @@ class MinDalle: text: str, seed: int = -1, grid_size: int = 1, - log2_mid_count: int = None, - handle_intermediate_image: Callable[[Image.Image], None] = None, is_verbose: bool = False ) -> Image.Image: - image_tokens = self.generate_image_tokens( - text, - seed, - grid_size, - row_count = 16, - log2_mid_count = log2_mid_count, - handle_intermediate_image = handle_intermediate_image, - is_verbose = is_verbose + log2_mid_count = 0 + image_stream = self.generate_image_stream( + text, + seed, + grid_size, + log2_mid_count, + is_verbose ) - return self.image_from_tokens(grid_size, image_tokens, is_verbose) \ No newline at end of file + return next(image_stream) \ No newline at end of file