@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)
@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
             state_index
         )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
         )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state


 class DalleBartDecoderLayerFlax(nn.Module):
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray


 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
             DalleBartDecoderLayerFlax,
             variable_axes = { "params": 0, "cache": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length = self.layer_count,
         )(
             self.image_token_count,
@@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
     def __call__(
         self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state

     def sample_image_tokens(
         self,
@@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token

         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )

         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )

         _, image_tokens = lax.scan(