works with latest flax version 0.5.2 now
This commit is contained in:
@@ -205,6 +205,7 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
params: dict
|
||||
) -> jnp.ndarray:
|
||||
attention_mask = jnp.not_equal(text_tokens, 1)
|
||||
encoder_state = encoder_state.astype(jnp.float16)
|
||||
|
||||
def sample_next_image_token(
|
||||
state: SampleState,
|
||||
@@ -247,8 +248,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
initial_state = SampleState(
|
||||
prev_token = self.start_token,
|
||||
prng_key = prng_key,
|
||||
keys_state = jnp.zeros(state_shape),
|
||||
values_state = jnp.zeros(state_shape)
|
||||
keys_state = jnp.zeros(state_shape, dtype=jnp.float16),
|
||||
values_state = jnp.zeros(state_shape, dtype=jnp.float16)
|
||||
)
|
||||
|
||||
_, image_tokens = lax.scan(
|
||||
|
Reference in New Issue
Block a user