|
|
@ -3,10 +3,10 @@ import torch |
|
|
|
from torch import LongTensor, nn, FloatTensor, BoolTensor |
|
|
|
from torch import LongTensor, nn, FloatTensor, BoolTensor |
|
|
|
torch.set_grad_enabled(False) |
|
|
|
torch.set_grad_enabled(False) |
|
|
|
|
|
|
|
|
|
|
|
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch |
|
|
|
from .dalle_bart_encoder import GLU, AttentionBase |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderCrossAttentionTorch(AttentionTorch): |
|
|
|
class DecoderCrossAttention(AttentionBase): |
|
|
|
def forward( |
|
|
|
def forward( |
|
|
|
self, |
|
|
|
self, |
|
|
|
decoder_state: FloatTensor, |
|
|
|
decoder_state: FloatTensor, |
|
|
@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch): |
|
|
|
return super().forward(keys, values, queries, attention_mask) |
|
|
|
return super().forward(keys, values, queries, attention_mask) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderSelfAttentionTorch(AttentionTorch): |
|
|
|
class DecoderSelfAttention(AttentionBase): |
|
|
|
def forward( |
|
|
|
def forward( |
|
|
|
self, |
|
|
|
self, |
|
|
|
decoder_state: FloatTensor, |
|
|
|
decoder_state: FloatTensor, |
|
|
@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch): |
|
|
|
return decoder_state, attention_state |
|
|
|
return decoder_state, attention_state |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderLayerTorch(nn.Module): |
|
|
|
class DecoderLayer(nn.Module): |
|
|
|
def __init__( |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
self, |
|
|
|
image_token_count: int, |
|
|
|
image_token_count: int, |
|
|
@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module): |
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
self.image_token_count = image_token_count |
|
|
|
self.image_token_count = image_token_count |
|
|
|
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count) |
|
|
|
self.self_attn = DecoderSelfAttention(head_count, embed_count) |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count) |
|
|
|
self.encoder_attn = DecoderCrossAttention(head_count, embed_count) |
|
|
|
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) |
|
|
|
self.glu = GLUTorch(embed_count, glu_embed_count) |
|
|
|
self.glu = GLU(embed_count, glu_embed_count) |
|
|
|
|
|
|
|
|
|
|
|
self.token_indices = torch.arange(self.image_token_count) |
|
|
|
self.token_indices = torch.arange(self.image_token_count) |
|
|
|
if torch.cuda.is_available(): |
|
|
|
if torch.cuda.is_available(): |
|
|
@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module): |
|
|
|
return decoder_state, attention_state |
|
|
|
return decoder_state, attention_state |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DalleBartDecoderTorch(nn.Module): |
|
|
|
class DalleBartDecoder(nn.Module): |
|
|
|
def __init__( |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
self, |
|
|
|
image_vocab_count: int, |
|
|
|
image_vocab_count: int, |
|
|
@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module): |
|
|
|
self.image_token_count = image_token_count |
|
|
|
self.image_token_count = image_token_count |
|
|
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) |
|
|
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) |
|
|
|
self.embed_positions = nn.Embedding(image_token_count, embed_count) |
|
|
|
self.embed_positions = nn.Embedding(image_token_count, embed_count) |
|
|
|
self.layers: List[DecoderLayerTorch] = nn.ModuleList([ |
|
|
|
self.layers: List[DecoderLayer] = nn.ModuleList([ |
|
|
|
DecoderLayerTorch( |
|
|
|
DecoderLayer( |
|
|
|
image_token_count, |
|
|
|
image_token_count, |
|
|
|
attention_head_count, |
|
|
|
attention_head_count, |
|
|
|
embed_count, |
|
|
|
embed_count, |