generate_image_stream

This commit is contained in:
Brett Kuprel 2022-07-04 20:02:33 -04:00
parent cf5b116284
commit 5f4815775b
4 changed files with 32 additions and 56 deletions

2
cog.yaml vendored
View File

@ -10,4 +10,4 @@ build:
run: run:
- pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
predict: "replicate/predict.py:Predictor" predict: "cogrun.py:Predictor"

View File

@ -1,8 +1,8 @@
from min_dalle import MinDalle
import tempfile import tempfile
from typing import Iterator
from cog import BasePredictor, Path, Input from cog import BasePredictor, Path, Input
from min_dalle import MinDalle
from PIL import Image
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
@ -30,19 +30,16 @@ class Predictor(BasePredictor):
le=4, le=4,
default=3 default=3
), ),
) -> Path: ) -> Iterator[Path]:
image_stream = self.model.generate_image_stream(
def handle_intermediate_image(i: int, image: Image.Image):
if i + 1 == 16: return
out_path = Path(tempfile.mkdtemp()) / 'output.jpg'
image.save(str(out_path))
image = self.model.generate_image(
text, text,
seed, seed,
grid_size=grid_size, grid_size=grid_size,
log2_mid_count=log2_intermediate_image_count, log2_mid_count=log2_intermediate_image_count,
handle_intermediate_image=handle_intermediate_image is_verbose=True
) )
return handle_intermediate_image(-1, image) for image in image_stream:
out_path = Path(tempfile.mkdtemp()) / 'output.jpg'
image.save(str(out_path))
yield out_path

View File

@ -1,6 +1,7 @@
import argparse import argparse
import os import os
from PIL import Image from PIL import Image
from matplotlib.pyplot import grid
from min_dalle import MinDalle from min_dalle import MinDalle
@ -13,7 +14,6 @@ parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--grid-size', type=int, default=1) parser.add_argument('--grid-size', type=int, default=1)
parser.add_argument('--image-path', type=str, default='generated') parser.add_argument('--image-path', type=str, default='generated')
parser.add_argument('--models-root', type=str, default='pretrained') parser.add_argument('--models-root', type=str, default='pretrained')
parser.add_argument('--row-count', type=int, default=16) # for debugging
def ascii_from_image(image: Image.Image, size: int = 128) -> str: def ascii_from_image(image: Image.Image, size: int = 128) -> str:
@ -40,8 +40,7 @@ def generate_image(
seed: int, seed: int,
grid_size: int, grid_size: int,
image_path: str, image_path: str,
models_root: str, models_root: str
row_count: int
): ):
model = MinDalle( model = MinDalle(
is_mega=is_mega, is_mega=is_mega,
@ -50,21 +49,9 @@ def generate_image(
is_verbose=True is_verbose=True
) )
if row_count < 16: image = model.generate_image(text, seed, grid_size, is_verbose=True)
token_count = 16 * row_count save_image(image, image_path)
image_tokens = model.generate_image_tokens( print(ascii_from_image(image, size=128))
text,
seed,
grid_size ** 2,
row_count,
is_verbose=True
)
image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy()
print('image tokens', image_tokens)
else:
image = model.generate_image(text, seed, grid_size, is_verbose=True)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__': if __name__ == '__main__':
@ -76,6 +63,5 @@ if __name__ == '__main__':
seed=args.seed, seed=args.seed,
grid_size=args.grid_size, grid_size=args.grid_size,
image_path=args.image_path, image_path=args.image_path,
models_root=args.models_root, models_root=args.models_root
row_count=args.row_count
) )

View File

@ -5,7 +5,7 @@ from torch import LongTensor, FloatTensor
import torch import torch
import json import json
import requests import requests
from typing import Callable, Tuple from typing import Callable, Tuple, Iterator
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count()) torch.set_num_threads(os.cpu_count())
@ -159,16 +159,14 @@ class MinDalle:
return image return image
def generate_image_tokens( def generate_image_stream(
self, self,
text: str, text: str,
seed: int, seed: int,
grid_size: int, grid_size: int,
row_count: int,
log2_mid_count: int = 0, log2_mid_count: int = 0,
handle_intermediate_image: Callable[[int, Image.Image], None] = None,
is_verbose: bool = False is_verbose: bool = False
) -> LongTensor: ) -> Iterator[Image.Image]:
if is_verbose: print("tokenizing text") if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
if is_verbose: print("text tokens", tokens) if is_verbose: print("text tokens", tokens)
@ -196,6 +194,7 @@ class MinDalle:
) )
) )
row_count = 16
for row_index in range(row_count): for row_index in range(row_count):
if is_verbose: if is_verbose:
print('sampling row {} of {}'.format(row_index + 1, row_count)) print('sampling row {} of {}'.format(row_index + 1, row_count))
@ -206,13 +205,10 @@ class MinDalle:
attention_state, attention_state,
image_tokens image_tokens
) )
if handle_intermediate_image is not None and log2_mid_count > 0: if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: tokens = image_tokens[:, 1:]
tokens = image_tokens[:, 1:] image = self.image_from_tokens(grid_size, tokens, is_verbose)
image = self.image_from_tokens(grid_size, tokens, is_verbose) yield image
handle_intermediate_image(row_index, image)
return image_tokens[:, 1:]
def generate_image( def generate_image(
@ -220,17 +216,14 @@ class MinDalle:
text: str, text: str,
seed: int = -1, seed: int = -1,
grid_size: int = 1, grid_size: int = 1,
log2_mid_count: int = None,
handle_intermediate_image: Callable[[Image.Image], None] = None,
is_verbose: bool = False is_verbose: bool = False
) -> Image.Image: ) -> Image.Image:
image_tokens = self.generate_image_tokens( log2_mid_count = 0
image_stream = self.generate_image_stream(
text, text,
seed, seed,
grid_size, grid_size,
row_count = 16, log2_mid_count,
log2_mid_count = log2_mid_count, is_verbose
handle_intermediate_image = handle_intermediate_image,
is_verbose = is_verbose
) )
return self.image_from_tokens(grid_size, image_tokens, is_verbose) return next(image_stream)