generate_image_stream

parent cf5b116284
commit 5f4815775b

cog.yaml (vendored): 2 changed lines

@@ -10,4 +10,4 @@ build:
   run:
     - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 
-predict: "replicate/predict.py:Predictor"
+predict: "cogrun.py:Predictor"

@@ -1,8 +1,8 @@
+from min_dalle import MinDalle
 import tempfile
+from typing import Iterator
 from cog import BasePredictor, Path, Input
 
-from min_dalle import MinDalle
-from PIL import Image
 
 class Predictor(BasePredictor):
     def setup(self):
@@ -30,19 +30,16 @@ class Predictor(BasePredictor):
             le=4,
             default=3
         ),
-    ) -> Path:
-        def handle_intermediate_image(i: int, image: Image.Image):
-            if i + 1 == 16: return
-            out_path = Path(tempfile.mkdtemp()) / 'output.jpg'
-            image.save(str(out_path))
-
-        image = self.model.generate_image(
+    ) -> Iterator[Path]:
+        image_stream = self.model.generate_image_stream(
             text,
             seed,
             grid_size=grid_size,
             log2_mid_count=log2_intermediate_image_count,
-            handle_intermediate_image=handle_intermediate_image
+            is_verbose=True
         )
 
-        return handle_intermediate_image(-1, image)
+        for image in image_stream:
+            out_path = Path(tempfile.mkdtemp()) / 'output.jpg'
+            image.save(str(out_path))
+            yield out_path
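
Note: reassembled for readability, the new predict() reads roughly as below. This is a sketch only — the Input(...) declarations for text, seed, grid_size, and log2_intermediate_image_count are not visible in the hunk above (only le=4 / default=3 survive as context), so those parameter definitions and the setup() body are assumptions, not the file's actual code; the method body mirrors the added lines.

import tempfile
from typing import Iterator
from min_dalle import MinDalle
from cog import BasePredictor, Path, Input


class Predictor(BasePredictor):
    def setup(self):
        # assumed: load the model once per container start
        self.model = MinDalle(is_mega=True, is_verbose=True)

    def predict(
        self,
        text: str = Input(description="Text prompt"),                       # assumed
        seed: int = Input(description="Random seed", default=-1),           # assumed
        grid_size: int = Input(ge=1, le=4, default=3),                      # assumed
        log2_intermediate_image_count: int = Input(ge=0, le=4, default=3),  # le=4 / default=3 per the hunk
    ) -> Iterator[Path]:
        # body mirrors the added lines in the hunk above
        image_stream = self.model.generate_image_stream(
            text,
            seed,
            grid_size=grid_size,
            log2_mid_count=log2_intermediate_image_count,
            is_verbose=True
        )
        for image in image_stream:
            out_path = Path(tempfile.mkdtemp()) / 'output.jpg'
            image.save(str(out_path))
            yield out_path

Because predict() now returns Iterator[Path], Cog emits each yielded file as progressive output instead of a single final image.
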

@@ -1,6 +1,7 @@
 import argparse
 import os
 from PIL import Image
+from matplotlib.pyplot import grid
 from min_dalle import MinDalle
 
 
@@ -13,7 +14,6 @@ parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--grid-size', type=int, default=1)
 parser.add_argument('--image-path', type=str, default='generated')
 parser.add_argument('--models-root', type=str, default='pretrained')
-parser.add_argument('--row-count', type=int, default=16) # for debugging
 
 
 def ascii_from_image(image: Image.Image, size: int = 128) -> str:
@@ -40,8 +40,7 @@ def generate_image(
     seed: int,
     grid_size: int,
     image_path: str,
-    models_root: str,
-    row_count: int
+    models_root: str
 ):
     model = MinDalle(
         is_mega=is_mega,
@@ -50,21 +49,9 @@ def generate_image(
         is_verbose=True
     )
 
-    if row_count < 16:
-        token_count = 16 * row_count
-        image_tokens = model.generate_image_tokens(
-            text,
-            seed,
-            grid_size ** 2,
-            row_count,
-            is_verbose=True
-        )
-        image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy()
-        print('image tokens', image_tokens)
-    else:
-        image = model.generate_image(text, seed, grid_size, is_verbose=True)
-        save_image(image, image_path)
-        print(ascii_from_image(image, size=128))
+    image = model.generate_image(text, seed, grid_size, is_verbose=True)
+    save_image(image, image_path)
+    print(ascii_from_image(image, size=128))
 
 
 if __name__ == '__main__':
@@ -76,6 +63,5 @@ if __name__ == '__main__':
         seed=args.seed,
         grid_size=args.grid_size,
         image_path=args.image_path,
-        models_root=args.models_root,
-        row_count=args.row_count
+        models_root=args.models_root
     )

@@ -5,7 +5,7 @@ from torch import LongTensor, FloatTensor
 import torch
 import json
 import requests
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Iterator
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
 
@@ -159,16 +159,14 @@ class MinDalle:
         return image
 
 
-    def generate_image_tokens(
+    def generate_image_stream(
         self,
         text: str,
         seed: int,
         grid_size: int,
-        row_count: int,
         log2_mid_count: int = 0,
-        handle_intermediate_image: Callable[[int, Image.Image], None] = None,
         is_verbose: bool = False
-    ) -> LongTensor:
+    ) -> Iterator[Image.Image]:
         if is_verbose: print("tokenizing text")
         tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
         if is_verbose: print("text tokens", tokens)
@@ -196,6 +194,7 @@ class MinDalle:
             )
         )
 
+        row_count = 16
         for row_index in range(row_count):
             if is_verbose:
                 print('sampling row {} of {}'.format(row_index + 1, row_count))
@@ -206,13 +205,10 @@ class MinDalle:
                 attention_state,
                 image_tokens
             )
-            if handle_intermediate_image is not None and log2_mid_count > 0:
-                if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
-                    tokens = image_tokens[:, 1:]
-                    image = self.image_from_tokens(grid_size, tokens, is_verbose)
-                    handle_intermediate_image(row_index, image)
-
-        return image_tokens[:, 1:]
+            if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
+                tokens = image_tokens[:, 1:]
+                image = self.image_from_tokens(grid_size, tokens, is_verbose)
+                yield image
 
 
     def generate_image(
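
The rewritten condition above is what paces the stream: with row_count fixed at 16, an intermediate grid is decoded and yielded every 16 / 2**log2_mid_count rows. A small standalone check of that cadence (plain Python, no model required):

# Which of the 16 sampled rows trigger a yield, per log2_mid_count value.
# Mirrors the new condition: ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0
row_count = 16
for log2_mid_count in range(5):
    rows = [
        row_index + 1
        for row_index in range(row_count)
        if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0
    ]
    print(log2_mid_count, rows)
# 0 -> [16]                          (final image only)
# 1 -> [8, 16]
# 2 -> [4, 8, 12, 16]
# 3 -> [2, 4, 6, 8, 10, 12, 14, 16]
# 4 -> [1, 2, 3, ..., 16]            (every row)
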
@@ -220,17 +216,14 @@ class MinDalle:
         text: str,
         seed: int = -1,
         grid_size: int = 1,
-        log2_mid_count: int = None,
-        handle_intermediate_image: Callable[[Image.Image], None] = None,
         is_verbose: bool = False
     ) -> Image.Image:
-        image_tokens = self.generate_image_tokens(
+        log2_mid_count = 0
+        image_stream = self.generate_image_stream(
             text,
             seed,
             grid_size,
-            row_count = 16,
-            log2_mid_count = log2_mid_count,
-            handle_intermediate_image = handle_intermediate_image,
-            is_verbose = is_verbose
+            log2_mid_count,
+            is_verbose
         )
-        return self.image_from_tokens(grid_size, image_tokens, is_verbose)
+        return next(image_stream)
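
Taken together, callers can now iterate over the stream directly instead of supplying a callback, while generate_image() keeps its one-shot behaviour by draining a single item with log2_mid_count = 0. A usage sketch — the prompt, file names, and is_mega value are placeholders; the call signatures follow the diff above:

from min_dalle import MinDalle

# Constructor keywords as used elsewhere in this diff; placeholder settings.
model = MinDalle(is_mega=False, is_verbose=True)

# Stream intermediate grids: log2_mid_count=3 yields 8 progressively refined
# images while the 16 token rows are sampled.
for i, image in enumerate(model.generate_image_stream(
    'a comfy chair shaped like an avocado',  # placeholder prompt
    seed=-1,
    grid_size=2,
    log2_mid_count=3,
    is_verbose=True
)):
    image.save('intermediate_{}.png'.format(i))

# One-shot use is unchanged: generate_image() wraps the stream with
# log2_mid_count=0 and returns the single final image via next().
final = model.generate_image('a comfy chair shaped like an avocado', seed=-1, grid_size=2)
final.save('final.png')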