mega works with latest flax version 0.5.2 now, removing 0.4.2 pin

This commit is contained in:
Brett Kuprel
2022-07-01 02:58:43 -04:00
parent eaee59a1ef
commit b40fd83a0d
5 changed files with 25 additions and 27 deletions

View File

@@ -33,7 +33,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
attention_state = lax.dynamic_update_slice(
attention_state,
jnp.concatenate([keys, values]),
jnp.concatenate([keys, values]).astype(jnp.float32),
state_index
)
batch_count = decoder_state.shape[0]
@@ -44,7 +44,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
values,
queries,
attention_mask
)
).astype(decoder_state.dtype)
return decoder_state, attention_state