From a2dca416238eb63b88fa306ffc633c0c1b954ea0 Mon Sep 17 00:00:00 2001
From: Brett Kuprel
Date: Fri, 15 Jul 2022 09:00:26 -0400
Subject: [PATCH] update tkinter ui

---
 min_dalle/models/dalle_bart_decoder.py | 76 +++++++++++++-------------
 ui.py => tkinter_ui.py                 | 29 +++++-----
 2 files changed, 51 insertions(+), 54 deletions(-)
 rename ui.py => tkinter_ui.py (90%)

diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index 31b0154..43d3fc8 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
             self.start_token = self.start_token.cuda()
 
 
+    def decode_initial(
+        self,
+        seed: int,
+        image_count: int,
+        text_tokens: LongTensor,
+        encoder_state: FloatTensor
+    ) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
+        expanded_indices = [0] * image_count + [1] * image_count
+        text_tokens = text_tokens[expanded_indices]
+        encoder_state = encoder_state[expanded_indices]
+        attention_mask = text_tokens.not_equal(1)
+
+        attention_state_shape = (
+            self.layer_count,
+            image_count * 4,
+            IMAGE_TOKEN_COUNT,
+            self.embed_count
+        )
+        attention_state = torch.zeros(attention_state_shape)
+        image_tokens_sequence = torch.full(
+            (image_count, IMAGE_TOKEN_COUNT + 1),
+            BLANK_TOKEN,
+            dtype=torch.long
+        )
+        if torch.cuda.is_available():
+            attention_state = attention_state.cuda()
+            image_tokens_sequence = image_tokens_sequence.cuda()
+
+        image_tokens_sequence[:, 0] = self.start_token[0]
+
+        if seed > 0: torch.manual_seed(seed)
+
+        return encoder_state, attention_mask, attention_state, image_tokens_sequence
+
+
     def decode_step(
         self,
         temperature: float,
         top_k: int,
-        supercondition_factor: int,
+        supercondition_factor: float,
         attention_mask: BoolTensor,
         encoder_state: FloatTensor,
         attention_state: FloatTensor,
@@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
         row_index: int,
         temperature: float,
         top_k: int,
-        supercondition_factor: int,
+        supercondition_factor: float,
         encoder_state: FloatTensor,
         attention_mask: BoolTensor,
         attention_state: FloatTensor,
@@ -207,39 +242,4 @@ class DalleBartDecoder(nn.Module):
         )
         image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
 
-        return attention_state, image_tokens_sequence
-
-
-    def decode_initial(
-        self,
-        seed: int,
-        image_count: int,
-        text_tokens: LongTensor,
-        encoder_state: FloatTensor
-    ) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
-        expanded_indices = [0] * image_count + [1] * image_count
-        text_tokens = text_tokens[expanded_indices]
-        encoder_state = encoder_state[expanded_indices]
-        attention_mask = text_tokens.not_equal(1)
-
-        attention_state_shape = (
-            self.layer_count,
-            image_count * 4,
-            IMAGE_TOKEN_COUNT,
-            self.embed_count
-        )
-        attention_state = torch.zeros(attention_state_shape)
-        image_tokens_sequence = torch.full(
-            (image_count, IMAGE_TOKEN_COUNT + 1),
-            BLANK_TOKEN,
-            dtype=torch.long
-        )
-        if torch.cuda.is_available():
-            attention_state = attention_state.cuda()
-            image_tokens_sequence = image_tokens_sequence.cuda()
-
-        image_tokens_sequence[:, 0] = self.start_token[0]
-
-        if seed > 0: torch.manual_seed(seed)
-
-        return encoder_state, attention_mask, attention_state, image_tokens_sequence
\ No newline at end of file
+        return attention_state, image_tokens_sequence
\ No newline at end of file
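
Note (not part of the patch): the relocated decode_initial is the entry point a caller runs once before sampling, and the second hunk retypes supercondition_factor to float on the row-wise sampler. A rough sketch of how the two compose, assuming the row method is named decode_row as in the min-dalle repository, that the image is a 16x16 grid of tokens, and that text_tokens and encoder_state come from the text tokenizer and encoder:

import torch

@torch.no_grad()
def sample_image_tokens(
    decoder, seed, image_count, text_tokens, encoder_state,
    temperature=1.0, top_k=128, supercondition_factor=16.0
):
    # decode_initial duplicates the text conditioning (unconditioned + conditioned
    # rows) for super-conditioning, allocates the per-layer attention cache, and
    # seeds the RNG when seed > 0.
    encoder_state, attention_mask, attention_state, image_tokens = decoder.decode_initial(
        seed, image_count, text_tokens, encoder_state
    )
    for row_index in range(16):  # 16 rows x 16 columns = 256 image tokens
        attention_state, image_tokens = decoder.decode_row(
            row_index,
            temperature=temperature,
            top_k=top_k,
            supercondition_factor=supercondition_factor,  # now a float
            encoder_state=encoder_state,
            attention_mask=attention_mask,
            attention_state=attention_state,
            image_tokens_sequence=image_tokens
        )
    return image_tokens[:, 1:]  # drop the start token kept at position 0

Each call presumably fills one row of tokens and returns the updated cache and sequence, which the caller feeds back in for the next row.
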
diff --git a/ui.py b/tkinter_ui.py
similarity index 90%
rename from ui.py
rename to tkinter_ui.py
index a67879f..06c2551 100644
--- a/ui.py
+++ b/tkinter_ui.py
@@ -6,8 +6,6 @@ import PIL.ImageTk
 import tkinter
 from tkinter import ttk
 
-# -- decide stuff --
-
 def regen_root():
     global root
     global blank_image
@@ -16,21 +14,19 @@ def regen_root():
     root = tkinter.Tk()
     root.wm_resizable(False, False)
-    blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256, 256), mode="RGB"))
+    blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 3, 256 * 3), mode="RGB"))
     padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
 regen_root()
 
-# -- --
-
-meganess = None
+is_mega = None
 def set_mega_true_and_destroy():
-    global meganess
-    meganess = True
+    global is_mega
+    is_mega = True
     root.destroy()
 
 def set_mega_false_and_destroy():
-    global meganess
-    meganess = False
+    global is_mega
+    is_mega = False
     root.destroy()
 
 
 frm = ttk.Frame(root, padding=16)
@@ -40,16 +36,16 @@ ttk.Label(frm, image=padding_image).grid(column=1, row=0)
 ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0)
 root.mainloop()
 
-if meganess is None:
-    print("no option selected, goodbye")
+if is_mega is None:
+    print("no option selected")
     sys.exit(0)
 
-print("confirmed mega: ", str(meganess))
+print("confirmed mega: ", str(is_mega))
 
 # -- --
 
 model = MinDalle(
-    is_mega=meganess,
+    is_mega=is_mega,
     models_root="./pretrained",
     is_reusable=True,
     is_verbose=True
@@ -63,9 +59,9 @@ regen_root()
 
 label_image_content = blank_image
 
-sv_prompt = tkinter.StringVar(value="mouse")
+sv_prompt = tkinter.StringVar(value="artificial intelligence")
 sv_temperature = tkinter.StringVar(value="1")
-sv_topk = tkinter.StringVar(value="1024")
+sv_topk = tkinter.StringVar(value="128")
 sv_supercond = tkinter.StringVar(value="16")
 
 def generate():
@@ -89,6 +85,7 @@ def generate():
     global label_image_content
     image = model.generate_image(
         sv_prompt.get(),
+        grid_size=3,
        temperature=temperature,
        top_k=topk,
        supercondition_factor=supercond,
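
Note (not part of the patch): with blank_image enlarged to 256 * 3 pixels per side and grid_size=3 passed to generate_image, the window now previews a 3x3 grid of 256-pixel tiles. A rough standalone sketch of the equivalent call outside the Tk loop, filling in the new defaults; treating the return value as a PIL.Image that can be saved or wrapped in PIL.ImageTk.PhotoImage is an assumption based on the min-dalle API:

from min_dalle import MinDalle

model = MinDalle(
    is_mega=False,              # normally chosen via the Mega / Not-Mega dialog
    models_root="./pretrained",
    is_reusable=True,
    is_verbose=True
)
image = model.generate_image(
    "artificial intelligence",  # new default prompt
    grid_size=3,                # 3x3 grid of 256px tiles, matching the 768px blank_image
    temperature=1,
    top_k=128,                  # new default (was 1024)
    supercondition_factor=16
)
image.save("grid_3x3.png")
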