diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index caf28ec..d2bc688 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax): ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: shape_split = decoder_state.shape[:2] + (self.head_count, -1) keys_state = lax.dynamic_update_slice( - keys_state, + keys_state, self.k_proj(decoder_state).reshape(shape_split), state_index ) @@ -205,6 +205,7 @@ class DalleBartDecoderFlax(nn.Module): params: dict ) -> jnp.ndarray: attention_mask = jnp.not_equal(text_tokens, 1) + encoder_state = encoder_state.astype(jnp.float16) def sample_next_image_token( state: SampleState, @@ -247,8 +248,8 @@ class DalleBartDecoderFlax(nn.Module): initial_state = SampleState( prev_token = self.start_token, prng_key = prng_key, - keys_state = jnp.zeros(state_shape), - values_state = jnp.zeros(state_shape) + keys_state = jnp.zeros(state_shape, dtype=jnp.float16), + values_state = jnp.zeros(state_shape, dtype=jnp.float16) ) _, image_tokens = lax.scan(