works with latest flax version 0.5.2 now

main
Brett Kuprel 2 years ago
parent 50e1e74c4a
commit 38ebe54a38
  1. 7
      min_dalle/models/dalle_bart_decoder_flax.py

@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
keys_state = lax.dynamic_update_slice(
keys_state,
keys_state,
self.k_proj(decoder_state).reshape(shape_split),
state_index
)
@ -205,6 +205,7 @@ class DalleBartDecoderFlax(nn.Module):
params: dict
) -> jnp.ndarray:
attention_mask = jnp.not_equal(text_tokens, 1)
encoder_state = encoder_state.astype(jnp.float16)
def sample_next_image_token(
state: SampleState,
@ -247,8 +248,8 @@ class DalleBartDecoderFlax(nn.Module):
initial_state = SampleState(
prev_token = self.start_token,
prng_key = prng_key,
keys_state = jnp.zeros(state_shape),
values_state = jnp.zeros(state_shape)
keys_state = jnp.zeros(state_shape, dtype=jnp.float16),
values_state = jnp.zeros(state_shape, dtype=jnp.float16)
)
_, image_tokens = lax.scan(

Loading…
Cancel
Save