update tkinter ui

This commit is contained in:
Brett Kuprel
2022-07-15 09:00:26 -04:00
parent ddfc806dd5
commit a2dca41623
2 changed files with 51 additions and 54 deletions

View File

@@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
self.start_token = self.start_token.cuda()
def decode_initial(
self,
seed: int,
image_count: int,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices]
attention_mask = text_tokens.not_equal(1)
attention_state_shape = (
self.layer_count,
image_count * 4,
IMAGE_TOKEN_COUNT,
self.embed_count
)
attention_state = torch.zeros(attention_state_shape)
image_tokens_sequence = torch.full(
(image_count, IMAGE_TOKEN_COUNT + 1),
BLANK_TOKEN,
dtype=torch.long
)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_tokens_sequence = image_tokens_sequence.cuda()
image_tokens_sequence[:, 0] = self.start_token[0]
if seed > 0: torch.manual_seed(seed)
return encoder_state, attention_mask, attention_state, image_tokens_sequence
def decode_step(
self,
temperature: float,
top_k: int,
supercondition_factor: int,
supercondition_factor: float,
attention_mask: BoolTensor,
encoder_state: FloatTensor,
attention_state: FloatTensor,
@@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
row_index: int,
temperature: float,
top_k: int,
supercondition_factor: int,
supercondition_factor: float,
encoder_state: FloatTensor,
attention_mask: BoolTensor,
attention_state: FloatTensor,
@@ -207,39 +242,4 @@ class DalleBartDecoder(nn.Module):
)
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
return attention_state, image_tokens_sequence
def decode_initial(
self,
seed: int,
image_count: int,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
expanded_indices = [0] * image_count + [1] * image_count
text_tokens = text_tokens[expanded_indices]
encoder_state = encoder_state[expanded_indices]
attention_mask = text_tokens.not_equal(1)
attention_state_shape = (
self.layer_count,
image_count * 4,
IMAGE_TOKEN_COUNT,
self.embed_count
)
attention_state = torch.zeros(attention_state_shape)
image_tokens_sequence = torch.full(
(image_count, IMAGE_TOKEN_COUNT + 1),
BLANK_TOKEN,
dtype=torch.long
)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_tokens_sequence = image_tokens_sequence.cuda()
image_tokens_sequence[:, 0] = self.start_token[0]
if seed > 0: torch.manual_seed(seed)
return encoder_state, attention_mask, attention_state, image_tokens_sequence
return attention_state, image_tokens_sequence