update tkinter ui
This commit is contained in:
parent
ddfc806dd5
commit
a2dca41623
|
@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
|
|||
self.start_token = self.start_token.cuda()
|
||||
|
||||
|
||||
def decode_initial(
|
||||
self,
|
||||
seed: int,
|
||||
image_count: int,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
||||
expanded_indices = [0] * image_count + [1] * image_count
|
||||
text_tokens = text_tokens[expanded_indices]
|
||||
encoder_state = encoder_state[expanded_indices]
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
|
||||
attention_state_shape = (
|
||||
self.layer_count,
|
||||
image_count * 4,
|
||||
IMAGE_TOKEN_COUNT,
|
||||
self.embed_count
|
||||
)
|
||||
attention_state = torch.zeros(attention_state_shape)
|
||||
image_tokens_sequence = torch.full(
|
||||
(image_count, IMAGE_TOKEN_COUNT + 1),
|
||||
BLANK_TOKEN,
|
||||
dtype=torch.long
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
attention_state = attention_state.cuda()
|
||||
image_tokens_sequence = image_tokens_sequence.cuda()
|
||||
|
||||
image_tokens_sequence[:, 0] = self.start_token[0]
|
||||
|
||||
if seed > 0: torch.manual_seed(seed)
|
||||
|
||||
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def decode_step(
|
||||
self,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: int,
|
||||
supercondition_factor: float,
|
||||
attention_mask: BoolTensor,
|
||||
encoder_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
|
@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
|
|||
row_index: int,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: int,
|
||||
supercondition_factor: float,
|
||||
encoder_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
attention_state: FloatTensor,
|
||||
|
@ -207,39 +242,4 @@ class DalleBartDecoder(nn.Module):
|
|||
)
|
||||
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
||||
|
||||
return attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def decode_initial(
|
||||
self,
|
||||
seed: int,
|
||||
image_count: int,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
||||
expanded_indices = [0] * image_count + [1] * image_count
|
||||
text_tokens = text_tokens[expanded_indices]
|
||||
encoder_state = encoder_state[expanded_indices]
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
|
||||
attention_state_shape = (
|
||||
self.layer_count,
|
||||
image_count * 4,
|
||||
IMAGE_TOKEN_COUNT,
|
||||
self.embed_count
|
||||
)
|
||||
attention_state = torch.zeros(attention_state_shape)
|
||||
image_tokens_sequence = torch.full(
|
||||
(image_count, IMAGE_TOKEN_COUNT + 1),
|
||||
BLANK_TOKEN,
|
||||
dtype=torch.long
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
attention_state = attention_state.cuda()
|
||||
image_tokens_sequence = image_tokens_sequence.cuda()
|
||||
|
||||
image_tokens_sequence[:, 0] = self.start_token[0]
|
||||
|
||||
if seed > 0: torch.manual_seed(seed)
|
||||
|
||||
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||
return attention_state, image_tokens_sequence
|
|
@ -6,8 +6,6 @@ import PIL.ImageTk
|
|||
import tkinter
|
||||
from tkinter import ttk
|
||||
|
||||
# -- decide stuff --
|
||||
|
||||
def regen_root():
|
||||
global root
|
||||
global blank_image
|
||||
|
@ -16,21 +14,19 @@ def regen_root():
|
|||
root = tkinter.Tk()
|
||||
root.wm_resizable(False, False)
|
||||
|
||||
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256, 256), mode="RGB"))
|
||||
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 3, 256 * 3), mode="RGB"))
|
||||
padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
|
||||
|
||||
regen_root()
|
||||
|
||||
# -- --
|
||||
|
||||
meganess = None
|
||||
is_mega = None
|
||||
def set_mega_true_and_destroy():
|
||||
global meganess
|
||||
meganess = True
|
||||
global is_mega
|
||||
is_mega = True
|
||||
root.destroy()
|
||||
def set_mega_false_and_destroy():
|
||||
global meganess
|
||||
meganess = False
|
||||
global is_mega
|
||||
is_mega = False
|
||||
root.destroy()
|
||||
|
||||
frm = ttk.Frame(root, padding=16)
|
||||
|
@ -40,16 +36,16 @@ ttk.Label(frm, image=padding_image).grid(column=1, row=0)
|
|||
ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0)
|
||||
root.mainloop()
|
||||
|
||||
if meganess is None:
|
||||
print("no option selected, goodbye")
|
||||
if is_mega is None:
|
||||
print("no option selected")
|
||||
sys.exit(0)
|
||||
|
||||
print("confirmed mega: ", str(meganess))
|
||||
print("confirmed mega: ", str(is_mega))
|
||||
|
||||
# -- --
|
||||
|
||||
model = MinDalle(
|
||||
is_mega=meganess,
|
||||
is_mega=is_mega,
|
||||
models_root="./pretrained",
|
||||
is_reusable=True,
|
||||
is_verbose=True
|
||||
|
@ -63,9 +59,9 @@ regen_root()
|
|||
|
||||
label_image_content = blank_image
|
||||
|
||||
sv_prompt = tkinter.StringVar(value="mouse")
|
||||
sv_prompt = tkinter.StringVar(value="artificial intelligence")
|
||||
sv_temperature = tkinter.StringVar(value="1")
|
||||
sv_topk = tkinter.StringVar(value="1024")
|
||||
sv_topk = tkinter.StringVar(value="128")
|
||||
sv_supercond = tkinter.StringVar(value="16")
|
||||
|
||||
def generate():
|
||||
|
@ -89,6 +85,7 @@ def generate():
|
|||
global label_image_content
|
||||
image = model.generate_image(
|
||||
sv_prompt.get(),
|
||||
grid_size=3,
|
||||
temperature=temperature,
|
||||
top_k=topk,
|
||||
supercondition_factor=supercond,
|
Loading…
Reference in New Issue
Block a user