update tkinter ui

main
Brett Kuprel 2 years ago
parent ddfc806dd5
commit a2dca41623
  1. 76
      min_dalle/models/dalle_bart_decoder.py
  2. 29
      tkinter_ui.py

@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
self.start_token = self.start_token.cuda()
def decode_initial(
self,
seed: int,
image_count: int,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices]
attention_mask = text_tokens.not_equal(1)
attention_state_shape = (
self.layer_count,
image_count * 4,
IMAGE_TOKEN_COUNT,
self.embed_count
)
attention_state = torch.zeros(attention_state_shape)
image_tokens_sequence = torch.full(
(image_count, IMAGE_TOKEN_COUNT + 1),
BLANK_TOKEN,
dtype=torch.long
)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_tokens_sequence = image_tokens_sequence.cuda()
image_tokens_sequence[:, 0] = self.start_token[0]
if seed > 0: torch.manual_seed(seed)
return encoder_state, attention_mask, attention_state, image_tokens_sequence
def decode_step(
self,
temperature: float,
top_k: int,
supercondition_factor: int,
supercondition_factor: float,
attention_mask: BoolTensor,
encoder_state: FloatTensor,
attention_state: FloatTensor,
@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
row_index: int,
temperature: float,
top_k: int,
supercondition_factor: int,
supercondition_factor: float,
encoder_state: FloatTensor,
attention_mask: BoolTensor,
attention_state: FloatTensor,
@ -207,39 +242,4 @@ class DalleBartDecoder(nn.Module):
)
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
return attention_state, image_tokens_sequence
def decode_initial(
self,
seed: int,
image_count: int,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices]
attention_mask = text_tokens.not_equal(1)
attention_state_shape = (
self.layer_count,
image_count * 4,
IMAGE_TOKEN_COUNT,
self.embed_count
)
attention_state = torch.zeros(attention_state_shape)
image_tokens_sequence = torch.full(
(image_count, IMAGE_TOKEN_COUNT + 1),
BLANK_TOKEN,
dtype=torch.long
)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_tokens_sequence = image_tokens_sequence.cuda()
image_tokens_sequence[:, 0] = self.start_token[0]
if seed > 0: torch.manual_seed(seed)
return encoder_state, attention_mask, attention_state, image_tokens_sequence
return attention_state, image_tokens_sequence

@ -6,8 +6,6 @@ import PIL.ImageTk
import tkinter
from tkinter import ttk
# -- decide stuff --
def regen_root():
global root
global blank_image
@ -16,21 +14,19 @@ def regen_root():
root = tkinter.Tk()
root.wm_resizable(False, False)
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256, 256), mode="RGB"))
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 3, 256 * 3), mode="RGB"))
padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
regen_root()
# -- --
meganess = None
is_mega = None
def set_mega_true_and_destroy():
global meganess
meganess = True
global is_mega
is_mega = True
root.destroy()
def set_mega_false_and_destroy():
global meganess
meganess = False
global is_mega
is_mega = False
root.destroy()
frm = ttk.Frame(root, padding=16)
@ -40,16 +36,16 @@ ttk.Label(frm, image=padding_image).grid(column=1, row=0)
ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0)
root.mainloop()
if meganess is None:
print("no option selected, goodbye")
if is_mega is None:
print("no option selected")
sys.exit(0)
print("confirmed mega: ", str(meganess))
print("confirmed mega: ", str(is_mega))
# -- --
model = MinDalle(
is_mega=meganess,
is_mega=is_mega,
models_root="./pretrained",
is_reusable=True,
is_verbose=True
@ -63,9 +59,9 @@ regen_root()
label_image_content = blank_image
sv_prompt = tkinter.StringVar(value="mouse")
sv_prompt = tkinter.StringVar(value="artificial intelligence")
sv_temperature = tkinter.StringVar(value="1")
sv_topk = tkinter.StringVar(value="1024")
sv_topk = tkinter.StringVar(value="128")
sv_supercond = tkinter.StringVar(value="16")
def generate():
@ -89,6 +85,7 @@ def generate():
global label_image_content
image = model.generate_image(
sv_prompt.get(),
grid_size=3,
temperature=temperature,
top_k=topk,
supercondition_factor=supercond,
Loading…
Cancel
Save