log2_mid_count

This commit is contained in:
Brett Kuprel 2022-07-04 17:27:23 -04:00
parent 059d8f3448
commit 1702d3c439
3 changed files with 7 additions and 7 deletions

View File

@ -165,7 +165,7 @@ class MinDalle:
seed: int, seed: int,
grid_size: int, grid_size: int,
row_count: int, row_count: int,
mid_count: int = None, log2_mid_count: int = 0,
handle_intermediate_image: Callable[[int, Image.Image], None] = None, handle_intermediate_image: Callable[[int, Image.Image], None] = None,
is_verbose: bool = False is_verbose: bool = False
) -> LongTensor: ) -> LongTensor:
@ -206,8 +206,8 @@ class MinDalle:
attention_state, attention_state,
image_tokens image_tokens
) )
if mid_count is not None: if handle_intermediate_image is not None:
if ((row_index + 1) * mid_count) % row_count == 0: if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
tokens = image_tokens[:, 1:] tokens = image_tokens[:, 1:]
image = self.image_from_tokens(grid_size, tokens, is_verbose) image = self.image_from_tokens(grid_size, tokens, is_verbose)
handle_intermediate_image(row_index, image) handle_intermediate_image(row_index, image)
@ -220,7 +220,7 @@ class MinDalle:
text: str, text: str,
seed: int = -1, seed: int = -1,
grid_size: int = 1, grid_size: int = 1,
mid_count: int = None, log2_mid_count: int = None,
handle_intermediate_image: Callable[[Image.Image], None] = None, handle_intermediate_image: Callable[[Image.Image], None] = None,
is_verbose: bool = False is_verbose: bool = False
) -> Image.Image: ) -> Image.Image:
@ -229,7 +229,7 @@ class MinDalle:
seed, seed,
grid_size, grid_size,
row_count = 16, row_count = 16,
mid_count = mid_count, log2_mid_count = log2_mid_count,
handle_intermediate_image = handle_intermediate_image, handle_intermediate_image = handle_intermediate_image,
is_verbose = is_verbose is_verbose = is_verbose
) )

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
long_description=(Path(__file__).parent / "README.rst").read_text(), long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.21', version='0.2.22',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',