diff --git a/.gitignore b/.gitignore index 4a56a38..ac3e6a8 100755 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ dist build README .cog +cog diff --git a/examples/artificial_intelligence.jpg b/examples/artificial_intelligence.jpg index d80950c..60efef9 100644 Binary files a/examples/artificial_intelligence.jpg and b/examples/artificial_intelligence.jpg differ diff --git a/examples/funeral.jpg b/examples/funeral.jpg deleted file mode 100644 index 978c2e4..0000000 Binary files a/examples/funeral.jpg and /dev/null differ diff --git a/examples/godzilla_trial.jpg b/examples/godzilla_trial.jpg deleted file mode 100644 index e3d397c..0000000 Binary files a/examples/godzilla_trial.jpg and /dev/null differ diff --git a/examples/gollum_trailcam.jpg b/examples/gollum_trailcam.jpg deleted file mode 100644 index c0a2435..0000000 Binary files a/examples/gollum_trailcam.jpg and /dev/null differ diff --git a/examples/ironman.jpg b/examples/ironman.jpg deleted file mode 100644 index 0134078..0000000 Binary files a/examples/ironman.jpg and /dev/null differ diff --git a/examples/jesus.jpg b/examples/jesus.jpg deleted file mode 100644 index c0c1fce..0000000 Binary files a/examples/jesus.jpg and /dev/null differ diff --git a/examples/panda_tophat_high_temp.jpg b/examples/panda_tophat_high_temp.jpg new file mode 100644 index 0000000..c129266 Binary files /dev/null and b/examples/panda_tophat_high_temp.jpg differ diff --git a/examples/panda_tophat_low_temp.jpg b/examples/panda_tophat_low_temp.jpg new file mode 100644 index 0000000..ee29bce Binary files /dev/null and b/examples/panda_tophat_low_temp.jpg differ diff --git a/examples/yoda.jpg b/examples/yoda.jpg deleted file mode 100644 index 75bafd3..0000000 Binary files a/examples/yoda.jpg and /dev/null differ diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 4e922a0..60c70a2 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -192,12 +192,12 @@ "%%time\n", "\n", "text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n", - "intermediate_outputs = True #@param {type:\"boolean\"}\n", + "progressive_outputs = True #@param {type:\"boolean\"}\n", "grid_size = 5 #@param {type:\"integer\"}\n", "temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n", "supercondition_factor = 16 #@param {type:\"number\"}\n", "top_k = 256 #@param {type:\"integer\"}\n", - "log2_mid_count = 3 if intermediate_outputs else 0\n", + "log2_mid_count = 3 if progressive_outputs else 0\n", "\n", "image_stream = model.generate_image_stream(\n", " text=text,\n", diff --git a/replicate_predictor.py b/replicate_predictor.py index e46be99..c5e7f1b 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -6,7 +6,6 @@ from cog import BasePredictor, Path, Input torch.backends.cudnn.deterministic = False - class ReplicatePredictor(BasePredictor): def setup(self): self.model = MinDalle( @@ -18,22 +17,37 @@ class ReplicatePredictor(BasePredictor): def predict( self, text: str = Input(default='Dali painting of WALL·E'), - output_png: bool = Input(default=False), - intermediate_outputs: bool = Input(default=True), + save_as_png: bool = Input(default=False), + progressive_outputs: bool = Input(default=True), grid_size: int = Input(ge=1, le=9, default=5), - log2_temperature: float = Input(ge=-3, le=3, default=2), - log2_top_k: int = Input(ge=0, le=14, default=4), - log2_supercondition_factor: float = Input(ge=2, le=6, default=4) + temperature: str = Input( + choices=( + ['1/{}'.format(2 ** i) for i in range(4, 0, -1)] + + [str(2 ** i) for i in range(5)] + ), + default='4', + description='Advanced Setting, see Readme below if interested.' + ), + top_k: int = Input( + choices=[2 ** i for i in range(15)], + default=64, + description='Advanced Setting, see Readme below if interested.' + ), + supercondition_factor: int = Input( + choices=[2 ** i for i in range(2, 7)], + default=16, + description='Advanced Setting, see Readme below if interested.' + ) ) -> Iterator[Path]: - log2_mid_count = 3 if intermediate_outputs else 0 + log2_mid_count = 3 if progressive_outputs else 0 image_stream = self.model.generate_image_stream( text = text, seed = -1, grid_size = grid_size, log2_mid_count = log2_mid_count, - temperature = 2 ** log2_temperature, - supercondition_factor = 2 ** log2_supercondition_factor, - top_k = 2 ** log2_top_k, + temperature = eval(temperature), + supercondition_factor = float(supercondition_factor), + top_k = top_k, is_verbose = True ) @@ -41,7 +55,7 @@ class ReplicatePredictor(BasePredictor): path = Path(tempfile.mkdtemp()) for image in image_stream: i += 1 - ext = 'png' if i == 2 ** log2_mid_count and output_png else 'jpg' + ext = 'png' if i == 2 ** log2_mid_count and save_as_png else 'jpg' image_path = path / 'min-dalle-iter-{}.{}'.format(i, ext) image.save(str(image_path)) yield image_path \ No newline at end of file