|
|
|
@@ -1,3 +1,4 @@ |
|
|
|
|
from typing import Tuple, List |
|
|
|
|
import torch |
|
|
|
|
from torch import LongTensor, nn, FloatTensor, BoolTensor |
|
|
|
|
torch.set_grad_enabled(False) |
|
|
|
@@ -25,7 +26,7 @@ class DecoderSelfAttention(AttentionBase): |
|
|
|
|
attention_state: FloatTensor, |
|
|
|
|
attention_mask: BoolTensor, |
|
|
|
|
token_mask: BoolTensor |
|
|
|
|
) -> tuple[FloatTensor, FloatTensor]: |
|
|
|
|
) -> Tuple[FloatTensor, FloatTensor]: |
|
|
|
|
keys = self.k_proj.forward(decoder_state) |
|
|
|
|
values = self.v_proj.forward(decoder_state) |
|
|
|
|
queries = self.q_proj.forward(decoder_state) |
|
|
|
@@ -70,7 +71,7 @@ class DecoderLayer(nn.Module): |
|
|
|
|
attention_state: FloatTensor, |
|
|
|
|
attention_mask: BoolTensor, |
|
|
|
|
token_index: LongTensor |
|
|
|
|
) -> tuple[FloatTensor, FloatTensor]: |
|
|
|
|
) -> Tuple[FloatTensor, FloatTensor]: |
|
|
|
|
# Self Attention |
|
|
|
|
residual = decoder_state |
|
|
|
|
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) |
|
|
|
@@ -125,7 +126,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
self.image_token_count = image_token_count |
|
|
|
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) |
|
|
|
|
self.embed_positions = nn.Embedding(image_token_count, embed_count) |
|
|
|
|
self.layers: list[DecoderLayer] = nn.ModuleList([ |
|
|
|
|
self.layers: List[DecoderLayer] = nn.ModuleList([ |
|
|
|
|
DecoderLayer( |
|
|
|
|
image_token_count, |
|
|
|
|
attention_head_count, |
|
|
|
@@ -153,7 +154,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
attention_state: FloatTensor, |
|
|
|
|
prev_tokens: LongTensor, |
|
|
|
|
token_index: LongTensor |
|
|
|
|
) -> tuple[LongTensor, FloatTensor]: |
|
|
|
|
) -> Tuple[LongTensor, FloatTensor]: |
|
|
|
|
image_count = encoder_state.shape[0] // 2 |
|
|
|
|
token_index_batched = token_index[[0] * image_count * 2] |
|
|
|
|
prev_tokens = prev_tokens[list(range(image_count)) * 2] |
|
|
|
@@ -209,7 +210,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
if torch.cuda.is_available(): attention_state = attention_state.cuda() |
|
|
|
|
|
|
|
|
|
image_tokens = self.start_token[[0] * image_count] |
|
|
|
|
image_tokens_sequence: list[LongTensor] = [] |
|
|
|
|
image_tokens_sequence: List[LongTensor] = [] |
|
|
|
|
for i in range(self.sample_token_count): |
|
|
|
|
probs, attention_state = self.decode_step( |
|
|
|
|
attention_mask = attention_mask, |
|
|
|
|