From 38ebe54a382f36dc7090b69f632d355310108701 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 28 Jun 2022 07:12:29 -0400 Subject: [PATCH] works with latest flax version 0.5.2 now --- min_dalle/models/dalle_bart_decoder_flax.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index caf28ec..d2bc688 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax): ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: shape_split = decoder_state.shape[:2] + (self.head_count, -1) keys_state = lax.dynamic_update_slice( - keys_state, + keys_state, self.k_proj(decoder_state).reshape(shape_split), state_index ) @@ -205,6 +205,7 @@ class DalleBartDecoderFlax(nn.Module): params: dict ) -> jnp.ndarray: attention_mask = jnp.not_equal(text_tokens, 1) + encoder_state = encoder_state.astype(jnp.float16) def sample_next_image_token( state: SampleState, @@ -247,8 +248,8 @@ class DalleBartDecoderFlax(nn.Module): initial_state = SampleState( prev_token = self.start_token, prng_key = prng_key, - keys_state = jnp.zeros(state_shape), - values_state = jnp.zeros(state_shape) + keys_state = jnp.zeros(state_shape, dtype=jnp.float16), + values_state = jnp.zeros(state_shape, dtype=jnp.float16) ) _, image_tokens = lax.scan(