diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index d2bc688..caf28ec 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax): ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: shape_split = decoder_state.shape[:2] + (self.head_count, -1) keys_state = lax.dynamic_update_slice( - keys_state, + keys_state, self.k_proj(decoder_state).reshape(shape_split), state_index ) @@ -205,7 +205,6 @@ class DalleBartDecoderFlax(nn.Module): params: dict ) -> jnp.ndarray: attention_mask = jnp.not_equal(text_tokens, 1) - encoder_state = encoder_state.astype(jnp.float16) def sample_next_image_token( state: SampleState, @@ -248,8 +247,8 @@ class DalleBartDecoderFlax(nn.Module): initial_state = SampleState( prev_token = self.start_token, prng_key = prng_key, - keys_state = jnp.zeros(state_shape, dtype=jnp.float16), - values_state = jnp.zeros(state_shape, dtype=jnp.float16) + keys_state = jnp.zeros(state_shape), + values_state = jnp.zeros(state_shape) ) _, image_tokens = lax.scan( diff --git a/requirements.txt b/requirements.txt index 165595e..d60c4e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch -flax \ No newline at end of file +torch==0.4.2 +flax