|
|
|
@ -171,7 +171,7 @@ class MinDalle: |
|
|
|
|
return images |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_stream( |
|
|
|
|
def generate_raw_image_stream( |
|
|
|
|
self, |
|
|
|
|
text: str, |
|
|
|
|
seed: int, |
|
|
|
@ -182,7 +182,7 @@ class MinDalle: |
|
|
|
|
top_k: int = 256, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Iterator[Image.Image]: |
|
|
|
|
) -> Iterator[FloatTensor]: |
|
|
|
|
image_count = grid_size ** 2 |
|
|
|
|
if is_verbose: print("tokenizing text") |
|
|
|
|
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) |
|
|
|
@ -249,84 +249,40 @@ class MinDalle: |
|
|
|
|
|
|
|
|
|
with torch.cuda.amp.autocast(dtype=torch.float32): |
|
|
|
|
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256: |
|
|
|
|
image = self.image_grid_from_tokens( |
|
|
|
|
yield self.image_grid_from_tokens( |
|
|
|
|
image_tokens=image_tokens[1:].T, |
|
|
|
|
is_seamless=is_seamless, |
|
|
|
|
is_verbose=is_verbose |
|
|
|
|
) |
|
|
|
|
image = image.to(torch.uint8).to('cpu').numpy() |
|
|
|
|
yield Image.fromarray(image) |
|
|
|
|
|
|
|
|
|
def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
    """
    Yield each grid produced by generate_raw_image_stream as an image object.

    All positional and keyword arguments are forwarded unchanged to
    generate_raw_image_stream. Every tensor it yields is cast to uint8,
    moved to the CPU, converted to a numpy array and wrapped with
    Image.fromarray.
    """
    for raw_grid in self.generate_raw_image_stream(*args, **kwargs):
        # Cast before the device transfer, exactly as the tensor arrives.
        pixel_array = raw_grid.to(torch.uint8).to('cpu').numpy()
        yield Image.fromarray(pixel_array)
|
|
|
|
|
|
|
|
|
def generate_image( |
|
|
|
|
self, |
|
|
|
|
text: str, |
|
|
|
|
seed: int = -1, |
|
|
|
|
grid_size: int = 1, |
|
|
|
|
temperature: float = 1, |
|
|
|
|
top_k: int = 1024, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Image.Image: |
|
|
|
|
|
|
|
|
|
def generate_images_stream(self, *args, **kwargs) -> Iterator[FloatTensor]:
    """
    Yield each grid from generate_raw_image_stream split into separate tiles.

    All positional and keyword arguments are forwarded unchanged to
    generate_raw_image_stream. Every yielded tensor has shape
    [grid_size ** 2, 256, 256, 3].

    NOTE(review): assumes each raw grid is a contiguous tensor laid out as
    [grid_size * 256, grid_size * 256, 3]; under that assumption tile
    ``n`` is the cell at row ``n % grid_size``, column ``n // grid_size``
    of the grid -- confirm against image_grid_from_tokens.

    Raises:
        TypeError: if the arguments do not match the signature of
            generate_raw_image_stream.
        KeyError: if generate_raw_image_stream has no grid_size parameter.
    """
    import inspect

    # Resolve grid_size once, through the wrapped method's own signature, so
    # it is found whether the caller passed it positionally or by keyword
    # (a plain kwargs['grid_size'] lookup fails for positional calls).
    raw_stream_fn = self.generate_raw_image_stream
    bound = inspect.signature(raw_stream_fn).bind(*args, **kwargs)
    bound.apply_defaults()
    grid_size = bound.arguments['grid_size']

    for image in raw_stream_fn(*args, **kwargs):
        # Expose the grid columns as their own axis, bring that axis to the
        # front, then merge it with the grid rows into one tile axis.
        image = image.view([grid_size * 256, grid_size, 256, 3])
        image = image.transpose(1, 0)
        image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
        yield image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image(self, *args, **kwargs) -> Image.Image:
    """
    Generate a single image and return it.

    All positional and keyword arguments are forwarded to
    generate_image_stream together with progressive_outputs=False, and the
    first item that stream yields is returned.

    Callers must not pass progressive_outputs themselves: it is supplied
    here, so passing it again raises TypeError (duplicate keyword argument).
    """
    # Only the forwarded arguments are used; the explicit text/seed/...
    # keywords belonged to the previous fixed-parameter signature and are
    # not defined under *args / **kwargs.
    image_stream = self.generate_image_stream(
        *args, **kwargs,
        progressive_outputs=False
    )
    return next(image_stream)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def images_from_image(image: Image.Image) -> FloatTensor: |
|
|
|
|
# pass |
|
|
|
|
|
|
|
|
|
# def generate_images_stream( |
|
|
|
|
# self, |
|
|
|
|
# text: str, |
|
|
|
|
# seed: int, |
|
|
|
|
# grid_size: int, |
|
|
|
|
# progressive_outputs: bool = False, |
|
|
|
|
# temperature: float = 1, |
|
|
|
|
# top_k: int = 256, |
|
|
|
|
# supercondition_factor: int = 16, |
|
|
|
|
# is_verbose: bool = False |
|
|
|
|
# ) -> Iterator[FloatTensor]: |
|
|
|
|
# image_stream = self.generate_image_stream( |
|
|
|
|
# text=text, |
|
|
|
|
# seed=seed, |
|
|
|
|
# image_count=grid_size ** 2, |
|
|
|
|
# progressive_outputs=progressive_outputs, |
|
|
|
|
# is_seamless=False, |
|
|
|
|
# temperature=temperature, |
|
|
|
|
# top_k=top_k, |
|
|
|
|
# supercondition_factor=supercondition_factor, |
|
|
|
|
# is_verbose=is_verbose |
|
|
|
|
# ) |
|
|
|
|
# for image in image_stream: |
|
|
|
|
# yield self.images_from_image(image) |
|
|
|
|
|
|
|
|
|
# def generate_images( |
|
|
|
|
# self, |
|
|
|
|
# text: str, |
|
|
|
|
# seed: int = -1, |
|
|
|
|
# image_count: int = 1, |
|
|
|
|
# temperature: float = 1, |
|
|
|
|
# top_k: int = 1024, |
|
|
|
|
# supercondition_factor: int = 16, |
|
|
|
|
# is_verbose: bool = False |
|
|
|
|
# ) -> FloatTensor: |
|
|
|
|
# images_stream = self.generate_images_stream( |
|
|
|
|
# text=text, |
|
|
|
|
# seed=seed, |
|
|
|
|
# image_count=image_count, |
|
|
|
|
# temperature=temperature, |
|
|
|
|
# progressive_outputs=False, |
|
|
|
|
# top_k=top_k, |
|
|
|
|
# supercondition_factor=supercondition_factor, |
|
|
|
|
# is_verbose=is_verbose |
|
|
|
|
# ) |
|
|
|
|
# return next(images_stream) |
|
|
|
|
def generate_images(self, *args, **kwargs) -> FloatTensor:
    """
    Generate one batch of images and return it as a tensor.

    All positional and keyword arguments are forwarded to
    generate_images_stream together with progressive_outputs=False, and the
    first item that stream yields is returned. The return annotation is
    FloatTensor to match what generate_images_stream yields.

    Callers must not pass progressive_outputs themselves: it is supplied
    here, so passing it again raises TypeError (duplicate keyword argument).
    """
    images_stream = self.generate_images_stream(
        *args, **kwargs,
        progressive_outputs=False
    )
    return next(images_stream)