@ -1,4 +1,4 @@
from typing import List , Tuple
from typing import Tuple , List
import torch
from torch import LongTensor , nn , FloatTensor , BoolTensor
torch . set_grad_enabled ( False )
@ -76,8 +76,8 @@ class DecoderLayer(nn.Module):
residual = decoder_state
decoder_state = self . pre_self_attn_layer_norm . forward ( decoder_state )
self_attn_mask = self . token_indices < token_index + 1
self_attn_mask = self_attn_mask [ None ] [ [ 0 ] * decoder_state . shape [ 0 ] ]
token_mask = self . token_indices == token_index
self_attn_mask = torch . stack ( [ self_attn_mask ] * decoder_state . shape [ 0 ] )
decoder_state , attention_state = self . self_attn . forward (
decoder_state ,
attention_state ,
@ -116,11 +116,11 @@ class DalleBartDecoder(nn.Module):
attention_head_count : int ,
glu_embed_count : int ,
layer_count : int ,
batch_count : int ,
start_token : int
) :
super ( ) . __init__ ( )
self . layer_count = layer_count
self . embed_count = embed_count
self . sample_token_count = sample_token_count
self . condition_factor = 10.0
self . image_token_count = image_token_count
@ -138,12 +138,6 @@ class DalleBartDecoder(nn.Module):
self . layernorm_embedding = nn . LayerNorm ( embed_count )
self . final_ln = nn . LayerNorm ( embed_count )
self . lm_head = nn . Linear ( embed_count , image_vocab_count + 1 , bias = False )
self . attention_state_shape = (
layer_count ,
2 * batch_count ,
image_token_count ,
embed_count
)
self . zero_prob = torch . zeros ( [ 1 ] )
self . token_indices = torch . arange ( self . sample_token_count )
self . start_token = torch . tensor ( [ start_token ] ) . to ( torch . long )
@ -155,17 +149,16 @@ class DalleBartDecoder(nn.Module):
def decode_step (
self ,
text_tokens : Long Tensor,
attention_mask : Bool Tensor,
encoder_state : FloatTensor ,
attention_state : FloatTensor ,
prev_token : LongTensor ,
prev_tokens : LongTensor ,
token_index : LongTensor
) - > Tuple [ LongTensor , FloatTensor ] :
attention_mask = text_tokens . not_equal ( 1 )
batch_count = encoder_state . shape [ 0 ]
prev_token_batched = torch . cat ( [ prev_token ] * batch_count )
token_index_batched = torch . cat ( [ token_index ] * batch_count )
decoder_state = self . embed_tokens . forward ( prev_token_batched )
image_count = encoder_state . shape [ 0 ] / / 2
token_index_batched = token_index [ [ 0 ] * image_count * 2 ]
prev_tokens = prev_tokens [ list ( range ( image_count ) ) * 2 ]
decoder_state = self . embed_tokens . forward ( prev_tokens )
decoder_state + = self . embed_positions . forward ( token_index_batched )
decoder_state = self . layernorm_embedding . forward ( decoder_state )
decoder_state = decoder_state [ : , None ]
@ -182,38 +175,52 @@ class DalleBartDecoder(nn.Module):
decoder_state = self . final_ln ( decoder_state )
logits = self . lm_head ( decoder_state )
a = self . condition_factor
logits : FloatTensor = ( 1 - a ) * logits [ 0 , - 1 ] + a * logits [ 1 , - 1 ]
logits : FloatTensor = (
logits [ : image_count , - 1 ] * ( 1 - a ) +
logits [ image_count : , - 1 ] * a
)
top_logits , _ = logits . topk ( 50 , dim = - 1 )
probs = torch . where (
logits < top_logits [ - 1 ] ,
logits < top_logits [ : , [ - 1 ] ] ,
self . zero_prob ,
torch . exp ( logits - top_logits [ 0 ] )
torch . exp ( logits - top_logits [ : , [ 0 ] ] )
)
return probs , torch . stack ( attention_states_new )
def forward (
self ,
image_count : int ,
text_tokens : LongTensor ,
encoder_state : FloatTensor
) - > LongTensor :
image_tokens : List [ LongTensor ] = [ ]
attention_state = torch . zeros ( self . attention_state_shape )
if torch . cuda . is_available ( ) :
attention_state = attention_state . cuda ( )
image_token = self . start_token
expanded_indices = [ 0 ] * image_count + [ 1 ] * image_count
text_tokens = text_tokens [ expanded_indices ]
encoder_state = encoder_state [ expanded_indices ]
attention_mask = text_tokens . not_equal ( 1 )
attention_state_shape = (
self . layer_count ,
image_count * 4 ,
self . image_token_count ,
self . embed_count
)
attention_state = torch . zeros ( attention_state_shape )
if torch . cuda . is_available ( ) : attention_state = attention_state . cuda ( )
image_tokens = self . start_token [ [ 0 ] * image_count ]
image_tokens_sequence : List [ LongTensor ] = [ ]
for i in range ( self . sample_token_count ) :
probs , attention_state = self . decode_step (
text_tokens = text_tokens ,
attention_mask = attention_mask ,
encoder_state = encoder_state ,
attention_state = attention_state ,
prev_token = image_token ,
prev_tokens = image_tokens ,
token_index = self . token_indices [ [ i ] ]
)
image_token = torch . multinomial ( probs , 1 )
image_tokens + = [ image_token ]
return torch . cat ( image_tokens )
image_tokens = torch . multinomial ( probs , 1 ) [ : , 0 ]
image_tokens_sequence + = [ image_tokens ]
return torch . stack ( image_tokens_sequence ) . T