works with latest flax version 0.5.2 now

This commit is contained in:
Brett Kuprel 2022-06-28 07:12:29 -04:00
parent 50e1e74c4a
commit 38ebe54a38

View File

@ -205,6 +205,7 @@ class DalleBartDecoderFlax(nn.Module):
params: dict
) -> jnp.ndarray:
attention_mask = jnp.not_equal(text_tokens, 1)
encoder_state = encoder_state.astype(jnp.float16)
def sample_next_image_token(
state: SampleState,
@ -247,8 +248,8 @@ class DalleBartDecoderFlax(nn.Module):
initial_state = SampleState(
prev_token = self.start_token,
prng_key = prng_key,
keys_state = jnp.zeros(state_shape),
values_state = jnp.zeros(state_shape)
keys_state = jnp.zeros(state_shape, dtype=jnp.float16),
values_state = jnp.zeros(state_shape, dtype=jnp.float16)
)
_, image_tokens = lax.scan(