update replicate to match colab

This commit is contained in:
Brett Kuprel 2022-07-04 20:36:23 -04:00
parent 92b4dc3333
commit b156e5c8c4

View File

@ -1,6 +1,7 @@
from min_dalle import MinDalle from min_dalle import MinDalle
import tempfile import tempfile
from typing import Iterator from typing import Iterator
from math import log2
from cog import BasePredictor, Path, Input from cog import BasePredictor, Path, Input
@ -24,22 +25,21 @@ class Predictor(BasePredictor):
description='Set the seed to a positive number for reproducible results', description='Set the seed to a positive number for reproducible results',
default=-1 default=-1
), ),
log2_intermediate_image_count: int = Input( intermediate_image_count: int = Input(
description='Set the log2 number of intermediate images to show', description='Set the number of intermediate images to show',
ge=0, choices=[1, 2, 4, 8, 16],
le=4, default=8
default=3
), ),
) -> Iterator[Path]: ) -> Iterator[Path]:
image_stream = self.model.generate_image_stream( image_stream = self.model.generate_image_stream(
text, text,
seed, seed,
grid_size=grid_size, grid_size=grid_size,
log2_mid_count=log2_intermediate_image_count, log2_mid_count=log2(intermediate_image_count),
is_verbose=True is_verbose=True
) )
for image in image_stream: for image in image_stream:
out_path = Path(tempfile.mkdtemp()) / 'output.jpg' path = Path(tempfile.mkdtemp()) / 'output.jpg'
image.save(str(out_path)) image.save(str(path))
yield out_path yield path