diff --git a/cogrun.py b/cogrun.py index cfbaf2a..2b044da 100644 --- a/cogrun.py +++ b/cogrun.py @@ -1,6 +1,7 @@ from min_dalle import MinDalle import tempfile from typing import Iterator +from math import log2 from cog import BasePredictor, Path, Input @@ -24,22 +25,21 @@ class Predictor(BasePredictor): description='Set the seed to a positive number for reproducible results', default=-1 ), - log2_intermediate_image_count: int = Input( - description='Set the log2 number of intermediate images to show', - ge=0, - le=4, - default=3 + intermediate_image_count: int = Input( + description='Set the number of intermediate images to show', + choices=[1, 2, 4, 8, 16], + default=8 ), ) -> Iterator[Path]: image_stream = self.model.generate_image_stream( text, seed, grid_size=grid_size, - log2_mid_count=log2_intermediate_image_count, + log2_mid_count=log2(intermediate_image_count), is_verbose=True ) for image in image_stream: - out_path = Path(tempfile.mkdtemp()) / 'output.jpg' - image.save(str(out_path)) - yield out_path \ No newline at end of file + path = Path(tempfile.mkdtemp()) / 'output.jpg' + image.save(str(path)) + yield path \ No newline at end of file