update tkinter ui
This commit is contained in:
@@ -138,11 +138,46 @@ class DalleBartDecoder(nn.Module):
|
||||
self.start_token = self.start_token.cuda()
|
||||
|
||||
|
||||
def decode_initial(
    self,
    seed: int,
    image_count: int,
    text_tokens: LongTensor,
    encoder_state: FloatTensor
) -> Tuple[FloatTensor, BoolTensor, FloatTensor, LongTensor]:
    """Prepare the tensors needed before the first decoding step.

    Rows 0 and 1 of ``text_tokens`` and ``encoder_state`` are each repeated
    ``image_count`` times (all copies of row 0 first, then all copies of
    row 1), so both inputs must have at least two rows along dim 0.
    NOTE(review): presumably the two rows are the two prompts used for
    super-conditioning -- confirm against the caller.

    Args:
        seed: When greater than 0, seeds the global torch RNG with this
            value; otherwise the RNG state is left untouched.
        image_count: Number of images to decode in this batch.
        text_tokens: Token ids; only rows 0 and 1 are read.
        encoder_state: Encoder output; only rows 0 and 1 are read.

    Returns:
        A tuple ``(encoder_state, attention_mask, attention_state,
        image_tokens_sequence)`` where

        * ``encoder_state`` is the row-expanded encoder output,
        * ``attention_mask`` is a boolean tensor that is True wherever the
          expanded ``text_tokens`` differ from 1 (presumably the padding
          id -- TODO confirm),
        * ``attention_state`` is a zero tensor of shape
          ``(layer_count, image_count * 4, IMAGE_TOKEN_COUNT, embed_count)``,
        * ``image_tokens_sequence`` is a long tensor of shape
          ``(image_count, IMAGE_TOKEN_COUNT + 1)`` filled with
          ``BLANK_TOKEN`` whose first column holds the start token.
    """
    expanded_indices = [0] * image_count + [1] * image_count
    text_tokens = text_tokens[expanded_indices]
    encoder_state = encoder_state[expanded_indices]
    attention_mask = text_tokens.not_equal(1)

    # Allocate directly on the target device instead of building the
    # buffers on the host and copying them over afterwards.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    attention_state_shape = (
        self.layer_count,
        image_count * 4,
        IMAGE_TOKEN_COUNT,
        self.embed_count
    )
    attention_state = torch.zeros(attention_state_shape, device=device)
    image_tokens_sequence = torch.full(
        (image_count, IMAGE_TOKEN_COUNT + 1),
        BLANK_TOKEN,
        dtype=torch.long,
        device=device
    )

    # Every sequence begins with the start token; the remaining positions
    # keep BLANK_TOKEN until they are overwritten elsewhere.
    image_tokens_sequence[:, 0] = self.start_token[0]

    # Non-positive seeds leave the global RNG state as it is.
    if seed > 0:
        torch.manual_seed(seed)

    return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def decode_step(
|
||||
self,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: int,
|
||||
supercondition_factor: float,
|
||||
attention_mask: BoolTensor,
|
||||
encoder_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
@@ -187,7 +222,7 @@ class DalleBartDecoder(nn.Module):
|
||||
row_index: int,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: int,
|
||||
supercondition_factor: float,
|
||||
encoder_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
attention_state: FloatTensor,
|
||||
@@ -207,39 +242,4 @@ class DalleBartDecoder(nn.Module):
|
||||
)
|
||||
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
||||
|
||||
return attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def decode_initial(
    self,
    seed: int,
    image_count: int,
    text_tokens: LongTensor,
    encoder_state: FloatTensor
) -> Tuple[FloatTensor, BoolTensor, FloatTensor, LongTensor]:
    """Prepare the tensors needed before the first decoding step.

    Rows 0 and 1 of ``text_tokens`` and ``encoder_state`` are each repeated
    ``image_count`` times (all copies of row 0 first, then all copies of
    row 1), so both inputs must have at least two rows along dim 0.
    NOTE(review): presumably the two rows are the two prompts used for
    super-conditioning -- confirm against the caller.

    Args:
        seed: When greater than 0, seeds the global torch RNG with this
            value; otherwise the RNG state is left untouched.
        image_count: Number of images to decode in this batch.
        text_tokens: Token ids; only rows 0 and 1 are read.
        encoder_state: Encoder output; only rows 0 and 1 are read.

    Returns:
        A tuple ``(encoder_state, attention_mask, attention_state,
        image_tokens_sequence)`` where

        * ``encoder_state`` is the row-expanded encoder output,
        * ``attention_mask`` is a boolean tensor that is True wherever the
          expanded ``text_tokens`` differ from 1 (presumably the padding
          id -- TODO confirm),
        * ``attention_state`` is a zero tensor of shape
          ``(layer_count, image_count * 4, IMAGE_TOKEN_COUNT, embed_count)``,
        * ``image_tokens_sequence`` is a long tensor of shape
          ``(image_count, IMAGE_TOKEN_COUNT + 1)`` filled with
          ``BLANK_TOKEN`` whose first column holds the start token.
    """
    expanded_indices = [0] * image_count + [1] * image_count
    text_tokens = text_tokens[expanded_indices]
    encoder_state = encoder_state[expanded_indices]
    attention_mask = text_tokens.not_equal(1)

    # Allocate directly on the target device instead of building the
    # buffers on the host and copying them over afterwards.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    attention_state_shape = (
        self.layer_count,
        image_count * 4,
        IMAGE_TOKEN_COUNT,
        self.embed_count
    )
    attention_state = torch.zeros(attention_state_shape, device=device)
    image_tokens_sequence = torch.full(
        (image_count, IMAGE_TOKEN_COUNT + 1),
        BLANK_TOKEN,
        dtype=torch.long,
        device=device
    )

    # Every sequence begins with the start token; the remaining positions
    # keep BLANK_TOKEN until they are overwritten elsewhere.
    image_tokens_sequence[:, 0] = self.start_token[0]

    # Non-positive seeds leave the global RNG state as it is.
    if seed > 0:
        torch.manual_seed(seed)

    return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||
return attention_state, image_tokens_sequence
|
Reference in New Issue
Block a user