update replicate, remove unused examples

This commit is contained in:
Brett Kuprel 2022-07-13 09:31:37 -04:00
parent 4e7b3b2caf
commit 4350c643f5
12 changed files with 28 additions and 13 deletions

1
.gitignore vendored
View File

@ -16,3 +16,4 @@ dist
build build
README README
.cog .cog
cog

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 23 KiB

BIN
examples/funeral.jpg vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 414 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 320 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 190 KiB

BIN
examples/ironman.jpg vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 504 KiB

BIN
examples/jesus.jpg vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 206 KiB

BIN
examples/panda_tophat_high_temp.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 294 KiB

BIN
examples/panda_tophat_low_temp.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 214 KiB

BIN
examples/yoda.jpg vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 250 KiB

4
min_dalle.ipynb vendored
View File

@ -192,12 +192,12 @@
"%%time\n", "%%time\n",
"\n", "\n",
"text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n", "text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n",
"intermediate_outputs = True #@param {type:\"boolean\"}\n", "progressive_outputs = True #@param {type:\"boolean\"}\n",
"grid_size = 5 #@param {type:\"integer\"}\n", "grid_size = 5 #@param {type:\"integer\"}\n",
"temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n", "temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n",
"supercondition_factor = 16 #@param {type:\"number\"}\n", "supercondition_factor = 16 #@param {type:\"number\"}\n",
"top_k = 256 #@param {type:\"integer\"}\n", "top_k = 256 #@param {type:\"integer\"}\n",
"log2_mid_count = 3 if intermediate_outputs else 0\n", "log2_mid_count = 3 if progressive_outputs else 0\n",
"\n", "\n",
"image_stream = model.generate_image_stream(\n", "image_stream = model.generate_image_stream(\n",
" text=text,\n", " text=text,\n",

View File

@ -6,7 +6,6 @@ from cog import BasePredictor, Path, Input
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
class ReplicatePredictor(BasePredictor): class ReplicatePredictor(BasePredictor):
def setup(self): def setup(self):
self.model = MinDalle( self.model = MinDalle(
@ -18,22 +17,37 @@ class ReplicatePredictor(BasePredictor):
def predict( def predict(
self, self,
text: str = Input(default='Dali painting of WALL·E'), text: str = Input(default='Dali painting of WALL·E'),
output_png: bool = Input(default=False), save_as_png: bool = Input(default=False),
intermediate_outputs: bool = Input(default=True), progressive_outputs: bool = Input(default=True),
grid_size: int = Input(ge=1, le=9, default=5), grid_size: int = Input(ge=1, le=9, default=5),
log2_temperature: float = Input(ge=-3, le=3, default=2), temperature: str = Input(
log2_top_k: int = Input(ge=0, le=14, default=4), choices=(
log2_supercondition_factor: float = Input(ge=2, le=6, default=4) ['1/{}'.format(2 ** i) for i in range(4, 0, -1)] +
[str(2 ** i) for i in range(5)]
),
default='4',
description='Advanced Setting, see Readme below if interested.'
),
top_k: int = Input(
choices=[2 ** i for i in range(15)],
default=64,
description='Advanced Setting, see Readme below if interested.'
),
supercondition_factor: int = Input(
choices=[2 ** i for i in range(2, 7)],
default=16,
description='Advanced Setting, see Readme below if interested.'
)
) -> Iterator[Path]: ) -> Iterator[Path]:
log2_mid_count = 3 if intermediate_outputs else 0 log2_mid_count = 3 if progressive_outputs else 0
image_stream = self.model.generate_image_stream( image_stream = self.model.generate_image_stream(
text = text, text = text,
seed = -1, seed = -1,
grid_size = grid_size, grid_size = grid_size,
log2_mid_count = log2_mid_count, log2_mid_count = log2_mid_count,
temperature = 2 ** log2_temperature, temperature = eval(temperature),
supercondition_factor = 2 ** log2_supercondition_factor, supercondition_factor = float(supercondition_factor),
top_k = 2 ** log2_top_k, top_k = top_k,
is_verbose = True is_verbose = True
) )
@ -41,7 +55,7 @@ class ReplicatePredictor(BasePredictor):
path = Path(tempfile.mkdtemp()) path = Path(tempfile.mkdtemp())
for image in image_stream: for image in image_stream:
i += 1 i += 1
ext = 'png' if i == 2 ** log2_mid_count and output_png else 'jpg' ext = 'png' if i == 2 ** log2_mid_count and save_as_png else 'jpg'
image_path = path / 'min-dalle-iter-{}.{}'.format(i, ext) image_path = path / 'min-dalle-iter-{}.{}'.format(i, ext)
image.save(str(image_path)) image.save(str(image_path))
yield image_path yield image_path