|
|
|
@ -1,7 +1,7 @@ |
|
|
|
|
from typing import List, Tuple |
|
|
|
|
import torch |
|
|
|
|
from torch import LongTensor, nn, FloatTensor, BoolTensor |
|
|
|
|
torch.no_grad() |
|
|
|
|
torch.set_grad_enabled(False) |
|
|
|
|
|
|
|
|
|
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch |
|
|
|
|
|
|
|
|
@ -30,14 +30,12 @@ class DecoderSelfAttentionTorch(AttentionTorch): |
|
|
|
|
decoder_state: FloatTensor, |
|
|
|
|
keys_values: FloatTensor, |
|
|
|
|
attention_mask: BoolTensor, |
|
|
|
|
token_index: LongTensor |
|
|
|
|
token_mask: BoolTensor |
|
|
|
|
) -> Tuple[FloatTensor, FloatTensor]: |
|
|
|
|
batch_count = decoder_state.shape[0] |
|
|
|
|
token_count = keys_values.shape[1] |
|
|
|
|
shape = (batch_count, 1) + keys_values.shape[2:] |
|
|
|
|
keys = self.k_proj.forward(decoder_state).view(shape) |
|
|
|
|
values = self.v_proj.forward(decoder_state).view(shape) |
|
|
|
|
token_mask = torch.arange(token_count) == token_index |
|
|
|
|
keys_values = torch.where( |
|
|
|
|
token_mask[None, :, None, None], |
|
|
|
|
torch.cat([keys, values]), |
|
|
|
@ -67,6 +65,10 @@ class DecoderLayerTorch(nn.Module): |
|
|
|
|
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
|
self.glu = GLUTorch(embed_count, glu_embed_count) |
|
|
|
|
|
|
|
|
|
self.token_indices = torch.arange(self.image_token_count) |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
self.token_indices = self.token_indices.cuda() |
|
|
|
|
|
|
|
|
|
def forward(self, |
|
|
|
|
decoder_state: FloatTensor, |
|
|
|
|
encoder_state: FloatTensor, |
|
|
|
@ -77,13 +79,14 @@ class DecoderLayerTorch(nn.Module): |
|
|
|
|
# Self Attention |
|
|
|
|
residual = decoder_state |
|
|
|
|
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) |
|
|
|
|
self_attn_mask = torch.arange(self.image_token_count) < token_index + 1 |
|
|
|
|
self_attn_mask = self.token_indices < token_index + 1 |
|
|
|
|
token_mask = self.token_indices == token_index |
|
|
|
|
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0]) |
|
|
|
|
decoder_state, keys_values_state = self.self_attn.forward( |
|
|
|
|
decoder_state, |
|
|
|
|
keys_values_state, |
|
|
|
|
self_attn_mask, |
|
|
|
|
token_index |
|
|
|
|
token_mask |
|
|
|
|
) |
|
|
|
|
decoder_state = self.self_attn_layer_norm.forward(decoder_state) |
|
|
|
|
decoder_state = residual + decoder_state |
|
|
|
@ -124,13 +127,7 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
self.is_verbose = is_verbose |
|
|
|
|
self.layer_count = layer_count |
|
|
|
|
self.sample_token_count = sample_token_count |
|
|
|
|
self.start_token = torch.tensor([start_token]).to(torch.long) |
|
|
|
|
self.pad_token = torch.tensor([1]).to(torch.long) |
|
|
|
|
self.condition_factor = torch.tensor([10]).to(torch.float) |
|
|
|
|
# if torch.cuda.is_available(): |
|
|
|
|
# self.start_token = self.start_token.cuda() |
|
|
|
|
# self.pad_token = self.pad_token.cuda() |
|
|
|
|
# self.condition_factor = self.condition_factor.cuda() |
|
|
|
|
self.condition_factor = 10.0 |
|
|
|
|
self.image_token_count = image_token_count |
|
|
|
|
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count) |
|
|
|
|
self.embed_positions = nn.Embedding(image_token_count, embed_count) |
|
|
|
@ -152,6 +149,13 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
attention_head_count, |
|
|
|
|
embed_count // attention_head_count |
|
|
|
|
) |
|
|
|
|
self.zero_prob = torch.zeros([1]) |
|
|
|
|
self.token_indices = torch.arange(self.sample_token_count) |
|
|
|
|
self.start_token = torch.tensor([start_token]).to(torch.long) |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
self.zero_prob = self.zero_prob.cuda() |
|
|
|
|
self.token_indices = self.token_indices.cuda() |
|
|
|
|
self.start_token = self.start_token.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_step(self, |
|
|
|
@ -160,7 +164,7 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
keys_values_state: FloatTensor, |
|
|
|
|
prev_token_and_index: LongTensor |
|
|
|
|
) -> Tuple[LongTensor, FloatTensor]: |
|
|
|
|
attention_mask = text_tokens.not_equal(self.pad_token) |
|
|
|
|
attention_mask = text_tokens.not_equal(1) |
|
|
|
|
batch_count = encoder_state.shape[0] |
|
|
|
|
prev_token = torch.cat([prev_token_and_index[:1]] * batch_count) |
|
|
|
|
token_index = torch.cat([prev_token_and_index[1:]] * batch_count) |
|
|
|
@ -188,7 +192,7 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
top_logits = logits.sort(descending=True)[0][:50] |
|
|
|
|
probs = torch.where( |
|
|
|
|
logits < top_logits[-1], |
|
|
|
|
torch.zeros([1]), |
|
|
|
|
self.zero_prob, |
|
|
|
|
torch.exp(logits - top_logits[0]) |
|
|
|
|
) |
|
|
|
|
return probs, keys_values |
|
|
|
@ -200,11 +204,12 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
) -> LongTensor: |
|
|
|
|
image_tokens: List[LongTensor] = [] |
|
|
|
|
keys_values_state = torch.zeros(self.keys_values_state_shape) |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
keys_values_state = keys_values_state.cuda() |
|
|
|
|
image_token = self.start_token |
|
|
|
|
|
|
|
|
|
for i in range(self.sample_token_count): |
|
|
|
|
token_index = torch.tensor([i]).to(torch.long) |
|
|
|
|
# if torch.cuda.is_available(): token_index = token_index.cuda() |
|
|
|
|
token_index = self.token_indices[i:i+1] |
|
|
|
|
probs, keys_values_state = self.decode_step( |
|
|
|
|
text_tokens = text_tokens, |
|
|
|
|
encoder_state = encoder_state, |
|
|
|
@ -214,9 +219,5 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
|
|
|
|
|
|
image_token = torch.multinomial(probs, 1) |
|
|
|
|
image_tokens += [image_token] |
|
|
|
|
|
|
|
|
|
if self.is_verbose: |
|
|
|
|
token = int(image_token.detach().numpy()) |
|
|
|
|
print("image token {} is {}".format(i, token)) |
|
|
|
|
|
|
|
|
|
return torch.cat(image_tokens) |