diff --git a/README.md b/README.md index 59990f7..6b9ff6b 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ display(image) Credit to [@hardmaru](https://twitter.com/hardmaru) for the [example](https://twitter.com/hardmaru/status/1544354119527596034) - +``` ### Progressive Outputs diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 48135bf..694ce4e 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -171,7 +171,7 @@ class MinDalle: return images - def generate_image_stream( + def generate_raw_image_stream( self, text: str, seed: int, @@ -182,7 +182,7 @@ class MinDalle: top_k: int = 256, supercondition_factor: int = 16, is_verbose: bool = False - ) -> Iterator[Image.Image]: + ) -> Iterator[FloatTensor]: image_count = grid_size ** 2 if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) @@ -249,84 +249,40 @@ class MinDalle: with torch.cuda.amp.autocast(dtype=torch.float32): if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256: - image = self.image_grid_from_tokens( + yield self.image_grid_from_tokens( image_tokens=image_tokens[1:].T, is_seamless=is_seamless, is_verbose=is_verbose ) - image = image.to(torch.uint8).to('cpu').numpy() - yield Image.fromarray(image) + def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]: + image_stream = self.generate_raw_image_stream(*args, **kwargs) + for image in image_stream: + image = image.to(torch.uint8).to('cpu').numpy() + yield Image.fromarray(image) - def generate_image( - self, - text: str, - seed: int = -1, - grid_size: int = 1, - temperature: float = 1, - top_k: int = 1024, - supercondition_factor: int = 16, - is_verbose: bool = False - ) -> Image.Image: + + def generate_images_stream(self, *args, **kwargs) -> Iterator[FloatTensor]: + image_stream = self.generate_raw_image_stream(*args, **kwargs) + for image in image_stream: + grid_size = kwargs['grid_size'] + image = image.view([grid_size * 256, grid_size, 256, 3]) + image = image.transpose(1, 0) + image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3]) + yield image + + + def generate_image(self, *args, **kwargs) -> Image.Image: image_stream = self.generate_image_stream( - text=text, - seed=seed, - grid_size=grid_size, - progressive_outputs=False, - temperature=temperature, - top_k=top_k, - supercondition_factor=supercondition_factor, - is_verbose=is_verbose + *args, **kwargs, + progressive_outputs=False ) return next(image_stream) - # def images_from_image(image: Image.Image) -> FloatTensor: - # pass - - # def generate_images_stream( - # self, - # text: str, - # seed: int, - # grid_size: int, - # progressive_outputs: bool = False, - # temperature: float = 1, - # top_k: int = 256, - # supercondition_factor: int = 16, - # is_verbose: bool = False - # ) -> Iterator[FloatTensor]: - # image_stream = self.generate_image_stream( - # text=text, - # seed=seed, - # image_count=grid_size ** 2, - # progressive_outputs=progressive_outputs, - # is_seamless=False, - # temperature=temperature, - # top_k=top_k, - # supercondition_factor=supercondition_factor, - # is_verbose=is_verbose - # ) - # for image in image_stream: - # yield self.images_from_image(image) - - # def generate_images( - # self, - # text: str, - # seed: int = -1, - # image_count: int = 1, - # temperature: float = 1, - # top_k: int = 1024, - # supercondition_factor: int = 16, - # is_verbose: bool = False - # ) -> FloatTensor: - # images_stream = self.generate_images_stream( - # text=text, - # seed=seed, - # image_count=image_count, - # temperature=temperature, - # progressive_outputs=False, - # top_k=top_k, - # supercondition_factor=supercondition_factor, - # is_verbose=is_verbose - # ) - # return next(images_stream) \ No newline at end of file + def generate_images(self, *args, **kwargs) -> Image.Image: + images_stream = self.generate_images_stream( + *args, **kwargs, + progressive_outputs=False + ) + return next(images_stream) \ No newline at end of file diff --git a/replicate_predictor.py b/replicate_predictor.py index 3bac013..24f4a7f 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -20,6 +20,7 @@ class ReplicatePredictor(BasePredictor): text: str = Input(default='Dali painting of WALL·E'), save_as_png: bool = Input(default=False), progressive_outputs: bool = Input(default=True), + seamless: bool = Input(default=False), grid_size: int = Input(ge=1, le=9, default=5), temperature: str = Input( choices=( @@ -45,6 +46,7 @@ class ReplicatePredictor(BasePredictor): seed = -1, grid_size = grid_size, progressive_outputs = progressive_outputs, + is_seamless=seamless, temperature = eval(temperature), supercondition_factor = float(supercondition_factor), top_k = top_k, diff --git a/setup.py b/setup.py index f1e49b2..a21f363 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.3.15', + version='0.3.16', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle', diff --git a/tkinter_ui.py b/tkinter_ui.py index a64d093..2fb0a84 100644 --- a/tkinter_ui.py +++ b/tkinter_ui.py @@ -101,7 +101,7 @@ def generate(): label_image.update() def save(): - final_image.save('out.png') + final_image.save('generated/out.png') frm = ttk.Frame(root, padding=16) frm.grid()