generate_image_stream
This commit is contained in:
@@ -5,7 +5,7 @@ from torch import LongTensor, FloatTensor
|
||||
import torch
|
||||
import json
|
||||
import requests
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Tuple, Iterator
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
|
||||
@@ -159,16 +159,14 @@ class MinDalle:
|
||||
return image
|
||||
|
||||
|
||||
def generate_image_tokens(
|
||||
def generate_image_stream(
|
||||
self,
|
||||
text: str,
|
||||
seed: int,
|
||||
grid_size: int,
|
||||
row_count: int,
|
||||
log2_mid_count: int = 0,
|
||||
handle_intermediate_image: Callable[[int, Image.Image], None] = None,
|
||||
is_verbose: bool = False
|
||||
) -> LongTensor:
|
||||
) -> Iterator[Image.Image]:
|
||||
if is_verbose: print("tokenizing text")
|
||||
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
|
||||
if is_verbose: print("text tokens", tokens)
|
||||
@@ -196,6 +194,7 @@ class MinDalle:
|
||||
)
|
||||
)
|
||||
|
||||
row_count = 16
|
||||
for row_index in range(row_count):
|
||||
if is_verbose:
|
||||
print('sampling row {} of {}'.format(row_index + 1, row_count))
|
||||
@@ -206,13 +205,10 @@ class MinDalle:
|
||||
attention_state,
|
||||
image_tokens
|
||||
)
|
||||
if handle_intermediate_image is not None and log2_mid_count > 0:
|
||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||
tokens = image_tokens[:, 1:]
|
||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||
handle_intermediate_image(row_index, image)
|
||||
|
||||
return image_tokens[:, 1:]
|
||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||
tokens = image_tokens[:, 1:]
|
||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||
yield image
|
||||
|
||||
|
||||
def generate_image(
|
||||
@@ -220,17 +216,14 @@ class MinDalle:
|
||||
text: str,
|
||||
seed: int = -1,
|
||||
grid_size: int = 1,
|
||||
log2_mid_count: int = None,
|
||||
handle_intermediate_image: Callable[[Image.Image], None] = None,
|
||||
is_verbose: bool = False
|
||||
) -> Image.Image:
|
||||
image_tokens = self.generate_image_tokens(
|
||||
text,
|
||||
seed,
|
||||
grid_size,
|
||||
row_count = 16,
|
||||
log2_mid_count = log2_mid_count,
|
||||
handle_intermediate_image = handle_intermediate_image,
|
||||
is_verbose = is_verbose
|
||||
log2_mid_count = 0
|
||||
image_stream = self.generate_image_stream(
|
||||
text,
|
||||
seed,
|
||||
grid_size,
|
||||
log2_mid_count,
|
||||
is_verbose
|
||||
)
|
||||
return self.image_from_tokens(grid_size, image_tokens, is_verbose)
|
||||
return next(image_stream)
|
Reference in New Issue
Block a user