@ -20,9 +20,9 @@ class DecoderCrossAttention(AttentionBase):
class DecoderSelfAttention ( AttentionBase ) :
def __init__ ( self , head_count : int , embed_count : int ) :
def __init__ ( self , head_count : int , embed_count : int , token_count : int ) :
super ( ) . __init__ ( head_count , embed_count )
token_indices = torch . arange ( 256 )
token_indices = torch . arange ( token_count )
if torch . cuda . is_available ( ) : token_indices = token_indices . cuda ( )
self . token_indices = token_indices
@ -56,7 +56,11 @@ class DecoderLayer(nn.Module):
super ( ) . __init__ ( )
self . image_token_count = image_token_count
self . pre_self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . self_attn = DecoderSelfAttention ( head_count , embed_count )
self . self_attn = DecoderSelfAttention (
head_count ,
embed_count ,
image_token_count
)
self . self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . pre_encoder_attn_layer_norm = nn . LayerNorm ( embed_count )
self . encoder_attn = DecoderCrossAttention ( head_count , embed_count )
@ -150,7 +154,7 @@ class DalleBartDecoder(nn.Module):
attention_state : FloatTensor ,
prev_tokens : LongTensor ,
token_index : LongTensor
) - > Tuple [ Long Tensor, FloatTensor ] :
) - > Tuple [ Float Tensor, FloatTensor ] :
image_count = encoder_state . shape [ 0 ] / / 2
token_index_batched = token_index [ [ 0 ] * image_count * 2 ]
prev_tokens = prev_tokens [ list ( range ( image_count ) ) * 2 ]
@ -158,16 +162,14 @@ class DalleBartDecoder(nn.Module):
decoder_state + = self . embed_positions . forward ( token_index_batched )
decoder_state = self . layernorm_embedding . forward ( decoder_state )
decoder_state = decoder_state [ : , None ]
attention_states_new = [ ]
for i in range ( self . layer_count ) :
decoder_state , attention_state_layer = self . layers [ i ] . forward (
decoder_state , attention_state [ i ] = self . layers [ i ] . forward (
decoder_state ,
encoder_state ,
attention_state [ i ] ,
attention_mask ,
token_index
)
attention_states_new . append ( attention_state_layer )
decoder_state = self . final_ln ( decoder_state )
logits = self . lm_head ( decoder_state )
a = self . condition_factor
@ -182,7 +184,7 @@ class DalleBartDecoder(nn.Module):
self . zero_prob ,
torch . exp ( logits - top_logits [ : , [ 0 ] ] )
)
return probs , torch . stack ( attention_states_new )
return probs , attention_state
def forward (
@ -203,10 +205,17 @@ class DalleBartDecoder(nn.Module):
self . embed_count
)
attention_state = torch . zeros ( attention_state_shape )
if torch . cuda . is_available ( ) : attention_state = attention_state . cuda ( )
image_tokens_sequence = torch . full (
( image_count , self . image_token_count ) ,
6965 , # black token
dtype = torch . long
)
if torch . cuda . is_available ( ) :
attention_state = attention_state . cuda ( )
image_tokens_sequence = image_tokens_sequence . cuda ( )
image_tokens = self . start_token [ [ 0 ] * image_count ]
image_tokens_sequence : List [ LongTensor ] = [ ]
for i in range ( self . sample_token_count ) :
probs , attention_state = self . decode_step (
attention_mask = attention_mask ,
@ -217,6 +226,6 @@ class DalleBartDecoder(nn.Module):
)
image_tokens = torch . multinomial ( probs , 1 ) [ : , 0 ]
image_tokens_sequence + = [ image_tokens ]
image_tokens_sequence [ : , i ] = image_tokens
return torch . stack ( image_tokens_sequence ) . T
return image_tokens_sequence