|
|
|
@ -5,7 +5,7 @@ from torch import LongTensor, FloatTensor |
|
|
|
|
import torch |
|
|
|
|
import json |
|
|
|
|
import requests |
|
|
|
|
from typing import Callable, Tuple |
|
|
|
|
from typing import Callable, Tuple, Iterator |
|
|
|
|
torch.set_grad_enabled(False) |
|
|
|
|
torch.set_num_threads(os.cpu_count()) |
|
|
|
|
|
|
|
|
@ -159,16 +159,14 @@ class MinDalle: |
|
|
|
|
return image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_tokens( |
|
|
|
|
def generate_image_stream( |
|
|
|
|
self, |
|
|
|
|
text: str, |
|
|
|
|
seed: int, |
|
|
|
|
grid_size: int, |
|
|
|
|
row_count: int, |
|
|
|
|
log2_mid_count: int = 0, |
|
|
|
|
handle_intermediate_image: Callable[[int, Image.Image], None] = None, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> LongTensor: |
|
|
|
|
) -> Iterator[Image.Image]: |
|
|
|
|
if is_verbose: print("tokenizing text") |
|
|
|
|
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) |
|
|
|
|
if is_verbose: print("text tokens", tokens) |
|
|
|
@ -196,6 +194,7 @@ class MinDalle: |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
row_count = 16 |
|
|
|
|
for row_index in range(row_count): |
|
|
|
|
if is_verbose: |
|
|
|
|
print('sampling row {} of {}'.format(row_index + 1, row_count)) |
|
|
|
@ -206,13 +205,10 @@ class MinDalle: |
|
|
|
|
attention_state, |
|
|
|
|
image_tokens |
|
|
|
|
) |
|
|
|
|
if handle_intermediate_image is not None and log2_mid_count > 0: |
|
|
|
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: |
|
|
|
|
tokens = image_tokens[:, 1:] |
|
|
|
|
image = self.image_from_tokens(grid_size, tokens, is_verbose) |
|
|
|
|
handle_intermediate_image(row_index, image) |
|
|
|
|
|
|
|
|
|
return image_tokens[:, 1:] |
|
|
|
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: |
|
|
|
|
tokens = image_tokens[:, 1:] |
|
|
|
|
image = self.image_from_tokens(grid_size, tokens, is_verbose) |
|
|
|
|
yield image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image( |
|
|
|
@ -220,17 +216,14 @@ class MinDalle: |
|
|
|
|
text: str, |
|
|
|
|
seed: int = -1, |
|
|
|
|
grid_size: int = 1, |
|
|
|
|
log2_mid_count: int = None, |
|
|
|
|
handle_intermediate_image: Callable[[Image.Image], None] = None, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Image.Image: |
|
|
|
|
image_tokens = self.generate_image_tokens( |
|
|
|
|
text, |
|
|
|
|
seed, |
|
|
|
|
grid_size, |
|
|
|
|
row_count = 16, |
|
|
|
|
log2_mid_count = log2_mid_count, |
|
|
|
|
handle_intermediate_image = handle_intermediate_image, |
|
|
|
|
is_verbose = is_verbose |
|
|
|
|
log2_mid_count = 0 |
|
|
|
|
image_stream = self.generate_image_stream( |
|
|
|
|
text, |
|
|
|
|
seed, |
|
|
|
|
grid_size, |
|
|
|
|
log2_mid_count, |
|
|
|
|
is_verbose |
|
|
|
|
) |
|
|
|
|
return self.image_from_tokens(grid_size, image_tokens, is_verbose) |
|
|
|
|
return next(image_stream) |