From 798e6ac5a3088671c76168feaddbb88a9f9added Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Sun, 17 Jul 2022 07:33:31 -0400 Subject: [PATCH] optionally tile images in token space --- README.md | 4 +- cog.yaml | 2 +- min_dalle/min_dalle.py | 129 +++++++++++++------------- min_dalle/models/vqgan_detokenizer.py | 94 +++++++++++-------- setup.py | 2 +- tkinter_ui.py | 21 ++++- 6 files changed, 145 insertions(+), 107 deletions(-) diff --git a/README.md b/README.md index 6b9ff6b..59990f7 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ display(image) Credit to [@hardmaru](https://twitter.com/hardmaru) for the [example](https://twitter.com/hardmaru/status/1544354119527596034) -### Saving Individual Images + ### Progressive Outputs diff --git a/cog.yaml b/cog.yaml index f724114..c0b0ce0 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.3.13" + - "min-dalle==0.3.15" run: - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index f38e499..48135bf 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -156,39 +156,34 @@ class MinDalle: self.detokenizer = self.detokenizer.to(device=self.device) - def images_from_tokens( + def image_grid_from_tokens( self, image_tokens: LongTensor, + is_seamless: bool, is_verbose: bool = False ) -> FloatTensor: if not self.is_reusable: del self.decoder torch.cuda.empty_cache() if not self.is_reusable: self.init_detokenizer() if is_verbose: print("detokenizing image") - images = self.detokenizer.forward(image_tokens).to(torch.uint8) + images = self.detokenizer.forward(is_seamless, image_tokens) if not self.is_reusable: del self.detokenizer return images - def grid_from_images(self, images: FloatTensor) -> Image.Image: - grid_size = int(sqrt(images.shape[0])) - images = images.reshape([grid_size] * 2 + list(images.shape[1:])) - image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2) - image = Image.fromarray(image.to('cpu').numpy()) - return image - - - def generate_images_stream( + def generate_image_stream( self, text: str, seed: int, - image_count: int, + grid_size: int, progressive_outputs: bool = False, + is_seamless: bool = False, temperature: float = 1, top_k: int = 256, supercondition_factor: int = 16, is_verbose: bool = False - ) -> Iterator[FloatTensor]: + ) -> Iterator[Image.Image]: + image_count = grid_size ** 2 if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) if len(tokens) > self.text_token_count: @@ -254,58 +249,13 @@ class MinDalle: with torch.cuda.amp.autocast(dtype=torch.float32): if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256: - yield self.images_from_tokens( - image_tokens=image_tokens[1:].T, + image = self.image_grid_from_tokens( + image_tokens=image_tokens[1:].T, + is_seamless=is_seamless, is_verbose=is_verbose ) - - - def generate_image_stream( - self, - text: str, - seed: int, - grid_size: int, - progressive_outputs: bool = False, - temperature: float = 1, - top_k: int = 256, - supercondition_factor: int = 16, - is_verbose: bool = False - ) -> Iterator[Image.Image]: - images_stream = self.generate_images_stream( - text=text, - seed=seed, - image_count=grid_size ** 2, - progressive_outputs=progressive_outputs, - temperature=temperature, - top_k=top_k, - supercondition_factor=supercondition_factor, - is_verbose=is_verbose - ) - for images in images_stream: - yield self.grid_from_images(images) - - - def generate_images( - self, - text: str, - seed: int = -1, - image_count: int = 1, - temperature: float = 1, - top_k: int = 1024, - supercondition_factor: int = 16, - is_verbose: bool = False - ) -> FloatTensor: - images_stream = self.generate_images_stream( - text=text, - seed=seed, - image_count=image_count, - temperature=temperature, - progressive_outputs=False, - top_k=top_k, - supercondition_factor=supercondition_factor, - is_verbose=is_verbose - ) - return next(images_stream) + image = image.to(torch.uint8).to('cpu').numpy() + yield Image.fromarray(image) def generate_image( @@ -328,4 +278,55 @@ class MinDalle: supercondition_factor=supercondition_factor, is_verbose=is_verbose ) - return next(image_stream) \ No newline at end of file + return next(image_stream) + + + # def images_from_image(image: Image.Image) -> FloatTensor: + # pass + + # def generate_images_stream( + # self, + # text: str, + # seed: int, + # grid_size: int, + # progressive_outputs: bool = False, + # temperature: float = 1, + # top_k: int = 256, + # supercondition_factor: int = 16, + # is_verbose: bool = False + # ) -> Iterator[FloatTensor]: + # image_stream = self.generate_image_stream( + # text=text, + # seed=seed, + # image_count=grid_size ** 2, + # progressive_outputs=progressive_outputs, + # is_seamless=False, + # temperature=temperature, + # top_k=top_k, + # supercondition_factor=supercondition_factor, + # is_verbose=is_verbose + # ) + # for image in image_stream: + # yield self.images_from_image(image) + + # def generate_images( + # self, + # text: str, + # seed: int = -1, + # image_count: int = 1, + # temperature: float = 1, + # top_k: int = 1024, + # supercondition_factor: int = 16, + # is_verbose: bool = False + # ) -> FloatTensor: + # images_stream = self.generate_images_stream( + # text=text, + # seed=seed, + # image_count=image_count, + # temperature=temperature, + # progressive_outputs=False, + # top_k=top_k, + # supercondition_factor=supercondition_factor, + # is_verbose=is_verbose + # ) + # return next(images_stream) \ No newline at end of file diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py index 5eb7ca0..00013d1 100644 --- a/min_dalle/models/vqgan_detokenizer.py +++ b/min_dalle/models/vqgan_detokenizer.py @@ -1,19 +1,20 @@ import torch +from torch import nn from torch import FloatTensor, LongTensor -from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding +from math import sqrt -class ResnetBlock(Module): +class ResnetBlock(nn.Module): def __init__(self, log2_count_in: int, log2_count_out: int): super().__init__() m, n = 2 ** log2_count_in, 2 ** log2_count_out self.is_middle = m == n - self.norm1 = GroupNorm(2 ** 5, m) - self.conv1 = Conv2d(m, n, 3, padding=1) - self.norm2 = GroupNorm(2 ** 5, n) - self.conv2 = Conv2d(n, n, 3, padding=1) + self.norm1 = nn.GroupNorm(2 ** 5, m) + self.conv1 = nn.Conv2d(m, n, 3, padding=1) + self.norm2 = nn.GroupNorm(2 ** 5, n) + self.conv2 = nn.Conv2d(n, n, 3, padding=1) if not self.is_middle: - self.nin_shortcut = Conv2d(m, n, 1) + self.nin_shortcut = nn.Conv2d(m, n, 1) def forward(self, x: FloatTensor) -> FloatTensor: h = x @@ -28,38 +29,39 @@ class ResnetBlock(Module): return x + h -class AttentionBlock(Module): +class AttentionBlock(nn.Module): def __init__(self): super().__init__() n = 2 ** 9 - self.norm = GroupNorm(2 ** 5, n) - self.q = Conv2d(n, n, 1) - self.k = Conv2d(n, n, 1) - self.v = Conv2d(n, n, 1) - self.proj_out = Conv2d(n, n, 1) + self.norm = nn.GroupNorm(2 ** 5, n) + self.q = nn.Conv2d(n, n, 1) + self.k = nn.Conv2d(n, n, 1) + self.v = nn.Conv2d(n, n, 1) + self.proj_out = nn.Conv2d(n, n, 1) def forward(self, x: FloatTensor) -> FloatTensor: n, m = 2 ** 9, x.shape[0] h = x h = self.norm(h) - q = self.q.forward(h) k = self.k.forward(h) v = self.v.forward(h) - q = q.reshape(m, n, 2 ** 8) + q = self.q.forward(h) + k = k.reshape(m, n, -1) + v = v.reshape(m, n, -1) + q = q.reshape(m, n, -1) q = q.permute(0, 2, 1) - k = k.reshape(m, n, 2 ** 8) w = torch.bmm(q, k) w /= n ** 0.5 w = torch.softmax(w, dim=2) - v = v.reshape(m, n, 2 ** 8) w = w.permute(0, 2, 1) h = torch.bmm(v, w) - h = h.reshape(m, n, 2 ** 4, 2 ** 4) + token_count = int(sqrt(h.shape[-1])) + h = h.reshape(m, n, token_count, token_count) h = self.proj_out.forward(h) return x + h -class MiddleLayer(Module): +class MiddleLayer(nn.Module): def __init__(self): super().__init__() self.block_1 = ResnetBlock(9, 9) @@ -73,12 +75,12 @@ class MiddleLayer(Module): return h -class Upsample(Module): +class Upsample(nn.Module): def __init__(self, log2_count): super().__init__() n = 2 ** log2_count self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) - self.conv = Conv2d(n, n, 3, padding=1) + self.conv = nn.Conv2d(n, n, 3, padding=1) def forward(self, x: FloatTensor) -> FloatTensor: x = self.upsample.forward(x.to(torch.float32)) @@ -86,7 +88,7 @@ class Upsample(Module): return x -class UpsampleBlock(Module): +class UpsampleBlock(nn.Module): def __init__( self, log2_count_in: int, @@ -97,19 +99,19 @@ class UpsampleBlock(Module): super().__init__() self.has_attention = has_attention self.has_upsample = has_upsample - self.block = ModuleList([ + + self.block = nn.ModuleList([ ResnetBlock(log2_count_in, log2_count_out), ResnetBlock(log2_count_out, log2_count_out), ResnetBlock(log2_count_out, log2_count_out) ]) + if has_attention: - self.attn = ModuleList([ + self.attn = nn.ModuleList([ AttentionBlock(), AttentionBlock(), AttentionBlock() ]) - else: - self.attn = ModuleList() if has_upsample: self.upsample = Upsample(log2_count_out) @@ -125,14 +127,14 @@ class UpsampleBlock(Module): return h -class Decoder(Module): +class Decoder(nn.Module): def __init__(self): super().__init__() - self.conv_in = Conv2d(2 ** 8, 2 ** 9, 3, padding=1) + self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1) self.mid = MiddleLayer() - self.up = ModuleList([ + self.up = nn.ModuleList([ UpsampleBlock(7, 7, False, False), UpsampleBlock(8, 7, False, True), UpsampleBlock(8, 8, False, True), @@ -140,8 +142,8 @@ class Decoder(Module): UpsampleBlock(9, 9, True, True) ]) - self.norm_out = GroupNorm(2 ** 5, 2 ** 7) - self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1) + self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7) + self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1) def forward(self, z: FloatTensor) -> FloatTensor: z = self.conv_in.forward(z) @@ -156,22 +158,40 @@ class Decoder(Module): return z -class VQGanDetokenizer(Module): +class VQGanDetokenizer(nn.Module): def __init__(self): super().__init__() vocab_count, embed_count = 2 ** 14, 2 ** 8 self.vocab_count = vocab_count - self.embedding = Embedding(vocab_count, embed_count) - self.post_quant_conv = Conv2d(embed_count, embed_count, 1) + self.embedding = nn.Embedding(vocab_count, embed_count) + self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1) self.decoder = Decoder() - def forward(self, z: LongTensor) -> FloatTensor: + def forward(self, is_seamless: bool, z: LongTensor) -> FloatTensor: z.clamp_(0, self.vocab_count - 1) - z = self.embedding.forward(z) - z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) + grid_size = int(sqrt(z.shape[0])) + token_count = grid_size * 2 ** 4 + + if is_seamless: + z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4]) + z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2) + z = z.flatten().unsqueeze(1) + z = self.embedding.forward(z) + z = z.view((1, token_count, token_count, 2 ** 8)) + else: + z = self.embedding.forward(z) + z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) + z = z.permute(0, 3, 1, 2).contiguous() z = self.post_quant_conv.forward(z) z = self.decoder.forward(z) z = z.permute(0, 2, 3, 1) z = z.clip(0.0, 1.0) * 255 + + if is_seamless: + z = z[0] + else: + z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3]) + z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2) + return z diff --git a/setup.py b/setup.py index 96a4199..f1e49b2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALLĀ·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.3.13', + version='0.3.15', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle', diff --git a/tkinter_ui.py b/tkinter_ui.py index 72155d6..a64d093 100644 --- a/tkinter_ui.py +++ b/tkinter_ui.py @@ -57,6 +57,7 @@ sv_prompt = tkinter.StringVar(value="artificial intelligence") sv_temperature = tkinter.StringVar(value="1") sv_topk = tkinter.StringVar(value="128") sv_supercond = tkinter.StringVar(value="16") +bv_seamless = tkinter.BooleanVar(value=False) def generate(): # check fields @@ -75,6 +76,10 @@ def generate(): except: sv_supercond.set("ERROR") return + try: + is_seamless = bool(bv_seamless.get()) + except: + return # and continue global label_image_content image_stream = model.generate_image_stream( @@ -82,16 +87,22 @@ def generate(): grid_size=2, seed=-1, progressive_outputs=True, + is_seamless=is_seamless, temperature=temperature, top_k=topk, supercondition_factor=supercond, is_verbose=True ) for image in image_stream: + global final_image + final_image = image label_image_content = PIL.ImageTk.PhotoImage(image) label_image.configure(image=label_image_content) label_image.update() +def save(): + final_image.save('out.png') + frm = ttk.Frame(root, padding=16) frm.grid() @@ -124,8 +135,14 @@ ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6) ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6) # ttk.Label(props, image=padding_image).grid(column=0, row=7) +# seamless +ttk.Label(props, text="Seamless:").grid(column=0, row=8) +ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8) +# +ttk.Label(props, image=padding_image).grid(column=0, row=9) # buttons -ttk.Button(props, text="Generate", command=generate).grid(column=0, row=8) -ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=8) +ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10) +ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=10) +ttk.Button(props, text="Save", command=save).grid(column=2, row=10) root.mainloop() \ No newline at end of file