|
|
|
@ -177,8 +177,9 @@ class MinDalle: |
|
|
|
|
seed: int, |
|
|
|
|
image_count: int, |
|
|
|
|
log2_mid_count: int, |
|
|
|
|
log2_k: int = 6, |
|
|
|
|
log2_supercondition_factor: int = 3, |
|
|
|
|
temperature: float = 1, |
|
|
|
|
top_k: int = 256, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Iterator[FloatTensor]: |
|
|
|
|
assert(log2_mid_count in range(5)) |
|
|
|
@ -206,10 +207,10 @@ class MinDalle: |
|
|
|
|
with torch.cuda.amp.autocast(dtype=self.dtype): |
|
|
|
|
encoder_state, attention_mask, attention_state, image_tokens = ( |
|
|
|
|
self.decoder.decode_initial( |
|
|
|
|
seed, |
|
|
|
|
image_count, |
|
|
|
|
text_tokens, |
|
|
|
|
encoder_state |
|
|
|
|
seed=seed, |
|
|
|
|
image_count=image_count, |
|
|
|
|
text_tokens=text_tokens, |
|
|
|
|
encoder_state=encoder_state |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -220,12 +221,13 @@ class MinDalle: |
|
|
|
|
with torch.cuda.amp.autocast(dtype=self.dtype): |
|
|
|
|
attention_state, image_tokens = self.decoder.decode_row( |
|
|
|
|
row_index, |
|
|
|
|
log2_k, |
|
|
|
|
log2_supercondition_factor, |
|
|
|
|
encoder_state, |
|
|
|
|
attention_mask, |
|
|
|
|
attention_state, |
|
|
|
|
image_tokens |
|
|
|
|
temperature=temperature, |
|
|
|
|
top_k=top_k, |
|
|
|
|
supercondition_factor=supercondition_factor, |
|
|
|
|
encoder_state=encoder_state, |
|
|
|
|
attention_mask=attention_mask, |
|
|
|
|
attention_state=attention_state, |
|
|
|
|
image_tokens_sequence=image_tokens |
|
|
|
|
) |
|
|
|
|
with torch.cuda.amp.autocast(dtype=torch.float32): |
|
|
|
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: |
|
|
|
@ -240,18 +242,20 @@ class MinDalle: |
|
|
|
|
seed: int, |
|
|
|
|
grid_size: int, |
|
|
|
|
log2_mid_count: int, |
|
|
|
|
log2_k: int = 6, |
|
|
|
|
log2_supercondition_factor: int = 3, |
|
|
|
|
temperature: float = 1, |
|
|
|
|
top_k: int = 256, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Iterator[Image.Image]: |
|
|
|
|
images_stream = self.generate_images_stream( |
|
|
|
|
text, |
|
|
|
|
seed, |
|
|
|
|
grid_size ** 2, |
|
|
|
|
log2_mid_count, |
|
|
|
|
log2_k, |
|
|
|
|
log2_supercondition_factor, |
|
|
|
|
is_verbose |
|
|
|
|
text=text, |
|
|
|
|
seed=seed, |
|
|
|
|
image_count=grid_size ** 2, |
|
|
|
|
log2_mid_count=log2_mid_count, |
|
|
|
|
temperature=temperature, |
|
|
|
|
top_k=top_k, |
|
|
|
|
supercondition_factor=supercondition_factor, |
|
|
|
|
is_verbose=is_verbose |
|
|
|
|
) |
|
|
|
|
for images in images_stream: |
|
|
|
|
yield self.grid_from_images(images) |
|
|
|
@ -262,19 +266,21 @@ class MinDalle: |
|
|
|
|
text: str, |
|
|
|
|
seed: int = -1, |
|
|
|
|
image_count: int = 1, |
|
|
|
|
log2_k: int = 6, |
|
|
|
|
log2_supercondition_factor: int = 3, |
|
|
|
|
temperature: float = 1, |
|
|
|
|
top_k: int = 1024, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> FloatTensor: |
|
|
|
|
log2_mid_count = 0 |
|
|
|
|
images_stream = self.generate_images_stream( |
|
|
|
|
text, |
|
|
|
|
seed, |
|
|
|
|
image_count, |
|
|
|
|
log2_mid_count, |
|
|
|
|
log2_k, |
|
|
|
|
log2_supercondition_factor, |
|
|
|
|
is_verbose |
|
|
|
|
text=text, |
|
|
|
|
seed=seed, |
|
|
|
|
image_count=image_count, |
|
|
|
|
temperature=temperature, |
|
|
|
|
log2_mid_count=log2_mid_count, |
|
|
|
|
top_k=top_k, |
|
|
|
|
supercondition_factor=supercondition_factor, |
|
|
|
|
is_verbose=is_verbose |
|
|
|
|
) |
|
|
|
|
return next(images_stream) |
|
|
|
|
|
|
|
|
@ -284,18 +290,20 @@ class MinDalle: |
|
|
|
|
text: str, |
|
|
|
|
seed: int = -1, |
|
|
|
|
grid_size: int = 1, |
|
|
|
|
log2_k: int = 6, |
|
|
|
|
log2_supercondition_factor: int = 3, |
|
|
|
|
temperature: float = 1, |
|
|
|
|
top_k: int = 1024, |
|
|
|
|
supercondition_factor: int = 16, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Image.Image: |
|
|
|
|
log2_mid_count = 0 |
|
|
|
|
image_stream = self.generate_image_stream( |
|
|
|
|
text, |
|
|
|
|
seed, |
|
|
|
|
grid_size, |
|
|
|
|
log2_mid_count, |
|
|
|
|
log2_k, |
|
|
|
|
log2_supercondition_factor, |
|
|
|
|
is_verbose |
|
|
|
|
text=text, |
|
|
|
|
seed=seed, |
|
|
|
|
grid_size=grid_size, |
|
|
|
|
log2_mid_count=log2_mid_count, |
|
|
|
|
temperature=temperature, |
|
|
|
|
top_k=top_k, |
|
|
|
|
supercondition_factor=supercondition_factor, |
|
|
|
|
is_verbose=is_verbose |
|
|
|
|
) |
|
|
|
|
return next(image_stream) |