mega works with latest flax version 0.5.2 now, removing 0.4.2 pin
This commit is contained in:
@@ -33,7 +33,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||
|
||||
attention_state = lax.dynamic_update_slice(
|
||||
attention_state,
|
||||
jnp.concatenate([keys, values]),
|
||||
jnp.concatenate([keys, values]).astype(jnp.float32),
|
||||
state_index
|
||||
)
|
||||
batch_count = decoder_state.shape[0]
|
||||
@@ -44,7 +44,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||
values,
|
||||
queries,
|
||||
attention_mask
|
||||
)
|
||||
).astype(decoder_state.dtype)
|
||||
return decoder_state, attention_state
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user