|
|
|
@ -165,7 +165,7 @@ class MinDalle: |
|
|
|
|
seed: int, |
|
|
|
|
grid_size: int, |
|
|
|
|
row_count: int, |
|
|
|
|
mid_count: int = None, |
|
|
|
|
log2_mid_count: int = 0, |
|
|
|
|
handle_intermediate_image: Callable[[int, Image.Image], None] = None, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> LongTensor: |
|
|
|
@ -206,8 +206,8 @@ class MinDalle: |
|
|
|
|
attention_state, |
|
|
|
|
image_tokens |
|
|
|
|
) |
|
|
|
|
if mid_count is not None: |
|
|
|
|
if ((row_index + 1) * mid_count) % row_count == 0: |
|
|
|
|
if handle_intermediate_image is not None: |
|
|
|
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: |
|
|
|
|
tokens = image_tokens[:, 1:] |
|
|
|
|
image = self.image_from_tokens(grid_size, tokens, is_verbose) |
|
|
|
|
handle_intermediate_image(row_index, image) |
|
|
|
@ -220,7 +220,7 @@ class MinDalle: |
|
|
|
|
text: str, |
|
|
|
|
seed: int = -1, |
|
|
|
|
grid_size: int = 1, |
|
|
|
|
mid_count: int = None, |
|
|
|
|
log2_mid_count: int = None, |
|
|
|
|
handle_intermediate_image: Callable[[Image.Image], None] = None, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Image.Image: |
|
|
|
@ -229,7 +229,7 @@ class MinDalle: |
|
|
|
|
seed, |
|
|
|
|
grid_size, |
|
|
|
|
row_count = 16, |
|
|
|
|
mid_count = mid_count, |
|
|
|
|
log2_mid_count = log2_mid_count, |
|
|
|
|
handle_intermediate_image = handle_intermediate_image, |
|
|
|
|
is_verbose = is_verbose |
|
|
|
|
) |
|
|
|
|