@ -5,6 +5,9 @@ torch.set_grad_enabled(False)
from . dalle_bart_encoder import GLU , AttentionBase
from . dalle_bart_encoder import GLU , AttentionBase
IMAGE_TOKEN_COUNT = 256
BLANK_TOKEN = 6965
class DecoderCrossAttention ( AttentionBase ) :
class DecoderCrossAttention ( AttentionBase ) :
def forward (
def forward (
@ -20,9 +23,9 @@ class DecoderCrossAttention(AttentionBase):
class DecoderSelfAttention ( AttentionBase ) :
class DecoderSelfAttention ( AttentionBase ) :
def __init__ ( self , head_count : int , embed_count : int , token_count : int ) :
def __init__ ( self , head_count : int , embed_count : int ) :
super ( ) . __init__ ( head_count , embed_count )
super ( ) . __init__ ( head_count , embed_count )
token_indices = torch . arange ( token_count )
token_indices = torch . arange ( IMAGE_TOKEN_COUNT )
if torch . cuda . is_available ( ) : token_indices = token_indices . cuda ( )
if torch . cuda . is_available ( ) : token_indices = token_indices . cuda ( )
self . token_indices = token_indices
self . token_indices = token_indices
@ -48,19 +51,13 @@ class DecoderSelfAttention(AttentionBase):
class DecoderLayer ( nn . Module ) :
class DecoderLayer ( nn . Module ) :
def __init__ (
def __init__ (
self ,
self ,
image_token_count : int ,
head_count : int ,
head_count : int ,
embed_count : int ,
embed_count : int ,
glu_embed_count : int
glu_embed_count : int
) :
) :
super ( ) . __init__ ( )
super ( ) . __init__ ( )
self . image_token_count = image_token_count
self . pre_self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . pre_self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . self_attn = DecoderSelfAttention (
self . self_attn = DecoderSelfAttention ( head_count , embed_count )
head_count ,
embed_count ,
image_token_count
)
self . self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . self_attn_layer_norm = nn . LayerNorm ( embed_count )
self . pre_encoder_attn_layer_norm = nn . LayerNorm ( embed_count )
self . pre_encoder_attn_layer_norm = nn . LayerNorm ( embed_count )
self . encoder_attn = DecoderCrossAttention ( head_count , embed_count )
self . encoder_attn = DecoderCrossAttention ( head_count , embed_count )
@ -110,7 +107,6 @@ class DalleBartDecoder(nn.Module):
def __init__ (
def __init__ (
self ,
self ,
image_vocab_count : int ,
image_vocab_count : int ,
image_token_count : int ,
embed_count : int ,
embed_count : int ,
attention_head_count : int ,
attention_head_count : int ,
glu_embed_count : int ,
glu_embed_count : int ,
@ -121,12 +117,10 @@ class DalleBartDecoder(nn.Module):
self . layer_count = layer_count
self . layer_count = layer_count
self . embed_count = embed_count
self . embed_count = embed_count
self . condition_factor = 10.0
self . condition_factor = 10.0
self . image_token_count = image_token_count
self . embed_tokens = nn . Embedding ( image_vocab_count + 1 , embed_count )
self . embed_tokens = nn . Embedding ( image_vocab_count + 1 , embed_count )
self . embed_positions = nn . Embedding ( image_token_count , embed_count )
self . embed_positions = nn . Embedding ( IMAGE_TOKEN_COUNT , embed_count )
self . layers : List [ DecoderLayer ] = nn . ModuleList ( [
self . layers : List [ DecoderLayer ] = nn . ModuleList ( [
DecoderLayer (
DecoderLayer (
image_token_count ,
attention_head_count ,
attention_head_count ,
embed_count ,
embed_count ,
glu_embed_count
glu_embed_count
@ -137,7 +131,7 @@ class DalleBartDecoder(nn.Module):
self . final_ln = nn . LayerNorm ( embed_count )
self . final_ln = nn . LayerNorm ( embed_count )
self . lm_head = nn . Linear ( embed_count , image_vocab_count + 1 , bias = False )
self . lm_head = nn . Linear ( embed_count , image_vocab_count + 1 , bias = False )
self . zero_prob = torch . zeros ( [ 1 ] )
self . zero_prob = torch . zeros ( [ 1 ] )
self . token_indices = torch . arange ( self . image_token_count )
self . token_indices = torch . arange ( IMAGE_TOKEN_COUNT )
self . start_token = torch . tensor ( [ start_token ] ) . to ( torch . long )
self . start_token = torch . tensor ( [ start_token ] ) . to ( torch . long )
if torch . cuda . is_available ( ) :
if torch . cuda . is_available ( ) :
self . zero_prob = self . zero_prob . cuda ( )
self . zero_prob = self . zero_prob . cuda ( )
@ -183,13 +177,13 @@ class DalleBartDecoder(nn.Module):
torch . exp ( logits - top_logits [ : , [ 0 ] ] )
torch . exp ( logits - top_logits [ : , [ 0 ] ] )
)
)
return probs , attention_state
return probs , attention_state
def decode_row (
def decode_row (
self ,
self ,
row_index : int ,
row_index : int ,
attention_mask : BoolTensor ,
encoder_state : FloatTensor ,
encoder_state : FloatTensor ,
attention_mask : BoolTensor ,
attention_state : FloatTensor ,
attention_state : FloatTensor ,
image_tokens_sequence : LongTensor
image_tokens_sequence : LongTensor
) - > Tuple [ FloatTensor , LongTensor ] :
) - > Tuple [ FloatTensor , LongTensor ] :
@ -202,19 +196,18 @@ class DalleBartDecoder(nn.Module):
prev_tokens = image_tokens_sequence [ : , i ] ,
prev_tokens = image_tokens_sequence [ : , i ] ,
token_index = self . token_indices [ [ i ] ]
token_index = self . token_indices [ [ i ] ]
)
)
image_tokens_sequence [ : , i + 1 ] = torch . multinomial ( probs , 1 ) [ : , 0 ]
image_tokens_sequence [ : , i + 1 ] = torch . multinomial ( probs , 1 ) [ : , 0 ]
return attention_state , image_tokens_sequence
return attention_state , image_tokens_sequence
def forward (
def decode_initial (
self ,
self ,
seed : int ,
image_count : int ,
image_count : int ,
row_count : int ,
text_tokens : LongTensor ,
text_tokens : LongTensor ,
encoder_state : FloatTensor
encoder_state : FloatTensor
) - > LongTensor :
) - > Tuple [ FloatTensor , FloatTensor , FloatTensor , LongTensor ] :
expanded_indices = [ 0 ] * image_count + [ 1 ] * image_count
expanded_indices = [ 0 ] * image_count + [ 1 ] * image_count
text_tokens = text_tokens [ expanded_indices ]
text_tokens = text_tokens [ expanded_indices ]
encoder_state = encoder_state [ expanded_indices ]
encoder_state = encoder_state [ expanded_indices ]
@ -223,13 +216,13 @@ class DalleBartDecoder(nn.Module):
attention_state_shape = (
attention_state_shape = (
self . layer_count ,
self . layer_count ,
image_count * 4 ,
image_count * 4 ,
self . image_token_count ,
IMAGE_TOKEN_COUNT ,
self . embed_count
self . embed_count
)
)
attention_state = torch . zeros ( attention_state_shape )
attention_state = torch . zeros ( attention_state_shape )
image_tokens_sequence = torch . full (
image_tokens_sequence = torch . full (
( image_count , self . image_token_count + 1 ) ,
( image_count , IMAGE_TOKEN_COUNT + 1 ) ,
6965 , # black token
BLANK_TOKEN ,
dtype = torch . long
dtype = torch . long
)
)
if torch . cuda . is_available ( ) :
if torch . cuda . is_available ( ) :
@ -238,13 +231,6 @@ class DalleBartDecoder(nn.Module):
image_tokens_sequence [ : , 0 ] = self . start_token [ 0 ]
image_tokens_sequence [ : , 0 ] = self . start_token [ 0 ]
for row_index in range ( row_count ) :
if seed > 0 : torch . manual_seed ( seed )
attention_state , image_tokens_sequence = self . decode_row (
row_index ,
return encoder_state , attention_mask , attention_state , image_tokens_sequence
attention_mask ,
encoder_state ,
attention_state ,
image_tokens_sequence
)
return image_tokens_sequence [ : , 1 : ]