diff --git a/replicate_predictor.py b/replicate_predictor.py index aa79916..c7143a0 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -13,26 +13,12 @@ class ReplicatePredictor(BasePredictor): def predict( self, - text: str = Input( - description='For long prompts, only the first 64 tokens will be used to generate the image.', - default='Dali painting of WALL·E' - ), - intermediate_outputs: bool = Input( - description='Whether to show intermediate outputs while running. This adds less than a second to the run time.', - default=True - ), - grid_size: int = Input( - description='Size of the image grid. 5x5 takes around 16 seconds, 8x8 takes around 36 seconds', - ge=1, - le=8, - default=4 - ), - temperature: float = Input( - description='A higher temperature results in more variety.', - ge=0.01, - le=10, - default=2 - ), + text: str = Input(default='Dali painting of WALL·E'), + intermediate_outputs: bool = Input(default=True), + grid_size: int = Input(ge=1, le=9, default=5), + log2_temperature: float = Input(ge=-3, le=3, default=1), + log2_top_k: int = Input(ge=0, le=14, default=7), + log2_supercondition_factor: int = Input(ge=2, le=6, default=4) ) -> Iterator[Path]: try: image_stream = self.model.generate_image_stream( @@ -40,9 +26,9 @@ class ReplicatePredictor(BasePredictor): seed = -1, grid_size = grid_size, log2_mid_count = 3 if intermediate_outputs else 0, - temperature = temperature, - supercondition_factor = 2 ** 4, - top_k = 2 ** 8, + temperature = 2 ** log2_temperature, + supercondition_factor = 2 ** log2_supercondition_factor, + top_k = 2 ** log2_top_k, is_verbose = True )