update tkinter ui
This commit is contained in:
parent
ddfc806dd5
commit
a2dca41623
|
@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
|
||||||
self.start_token = self.start_token.cuda()
|
self.start_token = self.start_token.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_initial(
|
||||||
|
self,
|
||||||
|
seed: int,
|
||||||
|
image_count: int,
|
||||||
|
text_tokens: LongTensor,
|
||||||
|
encoder_state: FloatTensor
|
||||||
|
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
||||||
|
expanded_indices = [0] * image_count + [1] * image_count
|
||||||
|
text_tokens = text_tokens[expanded_indices]
|
||||||
|
encoder_state = encoder_state[expanded_indices]
|
||||||
|
attention_mask = text_tokens.not_equal(1)
|
||||||
|
|
||||||
|
attention_state_shape = (
|
||||||
|
self.layer_count,
|
||||||
|
image_count * 4,
|
||||||
|
IMAGE_TOKEN_COUNT,
|
||||||
|
self.embed_count
|
||||||
|
)
|
||||||
|
attention_state = torch.zeros(attention_state_shape)
|
||||||
|
image_tokens_sequence = torch.full(
|
||||||
|
(image_count, IMAGE_TOKEN_COUNT + 1),
|
||||||
|
BLANK_TOKEN,
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
attention_state = attention_state.cuda()
|
||||||
|
image_tokens_sequence = image_tokens_sequence.cuda()
|
||||||
|
|
||||||
|
image_tokens_sequence[:, 0] = self.start_token[0]
|
||||||
|
|
||||||
|
if seed > 0: torch.manual_seed(seed)
|
||||||
|
|
||||||
|
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||||
|
|
||||||
|
|
||||||
def decode_step(
|
def decode_step(
|
||||||
self,
|
self,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
supercondition_factor: int,
|
supercondition_factor: float,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
attention_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
|
@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
|
||||||
row_index: int,
|
row_index: int,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
supercondition_factor: int,
|
supercondition_factor: float,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
attention_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
|
@ -208,38 +243,3 @@ class DalleBartDecoder(nn.Module):
|
||||||
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
||||||
|
|
||||||
return attention_state, image_tokens_sequence
|
return attention_state, image_tokens_sequence
|
||||||
|
|
||||||
|
|
||||||
def decode_initial(
|
|
||||||
self,
|
|
||||||
seed: int,
|
|
||||||
image_count: int,
|
|
||||||
text_tokens: LongTensor,
|
|
||||||
encoder_state: FloatTensor
|
|
||||||
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
|
||||||
expanded_indices = [0] * image_count + [1] * image_count
|
|
||||||
text_tokens = text_tokens[expanded_indices]
|
|
||||||
encoder_state = encoder_state[expanded_indices]
|
|
||||||
attention_mask = text_tokens.not_equal(1)
|
|
||||||
|
|
||||||
attention_state_shape = (
|
|
||||||
self.layer_count,
|
|
||||||
image_count * 4,
|
|
||||||
IMAGE_TOKEN_COUNT,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
attention_state = torch.zeros(attention_state_shape)
|
|
||||||
image_tokens_sequence = torch.full(
|
|
||||||
(image_count, IMAGE_TOKEN_COUNT + 1),
|
|
||||||
BLANK_TOKEN,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
attention_state = attention_state.cuda()
|
|
||||||
image_tokens_sequence = image_tokens_sequence.cuda()
|
|
||||||
|
|
||||||
image_tokens_sequence[:, 0] = self.start_token[0]
|
|
||||||
|
|
||||||
if seed > 0: torch.manual_seed(seed)
|
|
||||||
|
|
||||||
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
|
|
@ -6,8 +6,6 @@ import PIL.ImageTk
|
||||||
import tkinter
|
import tkinter
|
||||||
from tkinter import ttk
|
from tkinter import ttk
|
||||||
|
|
||||||
# -- decide stuff --
|
|
||||||
|
|
||||||
def regen_root():
|
def regen_root():
|
||||||
global root
|
global root
|
||||||
global blank_image
|
global blank_image
|
||||||
|
@ -16,21 +14,19 @@ def regen_root():
|
||||||
root = tkinter.Tk()
|
root = tkinter.Tk()
|
||||||
root.wm_resizable(False, False)
|
root.wm_resizable(False, False)
|
||||||
|
|
||||||
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256, 256), mode="RGB"))
|
blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 3, 256 * 3), mode="RGB"))
|
||||||
padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
|
padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
|
||||||
|
|
||||||
regen_root()
|
regen_root()
|
||||||
|
|
||||||
# -- --
|
is_mega = None
|
||||||
|
|
||||||
meganess = None
|
|
||||||
def set_mega_true_and_destroy():
|
def set_mega_true_and_destroy():
|
||||||
global meganess
|
global is_mega
|
||||||
meganess = True
|
is_mega = True
|
||||||
root.destroy()
|
root.destroy()
|
||||||
def set_mega_false_and_destroy():
|
def set_mega_false_and_destroy():
|
||||||
global meganess
|
global is_mega
|
||||||
meganess = False
|
is_mega = False
|
||||||
root.destroy()
|
root.destroy()
|
||||||
|
|
||||||
frm = ttk.Frame(root, padding=16)
|
frm = ttk.Frame(root, padding=16)
|
||||||
|
@ -40,16 +36,16 @@ ttk.Label(frm, image=padding_image).grid(column=1, row=0)
|
||||||
ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0)
|
ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0)
|
||||||
root.mainloop()
|
root.mainloop()
|
||||||
|
|
||||||
if meganess is None:
|
if is_mega is None:
|
||||||
print("no option selected, goodbye")
|
print("no option selected")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
print("confirmed mega: ", str(meganess))
|
print("confirmed mega: ", str(is_mega))
|
||||||
|
|
||||||
# -- --
|
# -- --
|
||||||
|
|
||||||
model = MinDalle(
|
model = MinDalle(
|
||||||
is_mega=meganess,
|
is_mega=is_mega,
|
||||||
models_root="./pretrained",
|
models_root="./pretrained",
|
||||||
is_reusable=True,
|
is_reusable=True,
|
||||||
is_verbose=True
|
is_verbose=True
|
||||||
|
@ -63,9 +59,9 @@ regen_root()
|
||||||
|
|
||||||
label_image_content = blank_image
|
label_image_content = blank_image
|
||||||
|
|
||||||
sv_prompt = tkinter.StringVar(value="mouse")
|
sv_prompt = tkinter.StringVar(value="artificial intelligence")
|
||||||
sv_temperature = tkinter.StringVar(value="1")
|
sv_temperature = tkinter.StringVar(value="1")
|
||||||
sv_topk = tkinter.StringVar(value="1024")
|
sv_topk = tkinter.StringVar(value="128")
|
||||||
sv_supercond = tkinter.StringVar(value="16")
|
sv_supercond = tkinter.StringVar(value="16")
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
|
@ -89,6 +85,7 @@ def generate():
|
||||||
global label_image_content
|
global label_image_content
|
||||||
image = model.generate_image(
|
image = model.generate_image(
|
||||||
sv_prompt.get(),
|
sv_prompt.get(),
|
||||||
|
grid_size=3,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_k=topk,
|
top_k=topk,
|
||||||
supercondition_factor=supercond,
|
supercondition_factor=supercond,
|
Loading…
Reference in New Issue
Block a user