|
|
|
@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax): |
|
|
|
|
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: |
|
|
|
|
shape_split = decoder_state.shape[:2] + (self.head_count, -1) |
|
|
|
|
keys_state = lax.dynamic_update_slice( |
|
|
|
|
keys_state, |
|
|
|
|
keys_state, |
|
|
|
|
self.k_proj(decoder_state).reshape(shape_split), |
|
|
|
|
state_index |
|
|
|
|
) |
|
|
|
@ -205,7 +205,6 @@ class DalleBartDecoderFlax(nn.Module): |
|
|
|
|
params: dict |
|
|
|
|
) -> jnp.ndarray: |
|
|
|
|
attention_mask = jnp.not_equal(text_tokens, 1) |
|
|
|
|
encoder_state = encoder_state.astype(jnp.float16) |
|
|
|
|
|
|
|
|
|
def sample_next_image_token( |
|
|
|
|
state: SampleState, |
|
|
|
@ -248,8 +247,8 @@ class DalleBartDecoderFlax(nn.Module): |
|
|
|
|
initial_state = SampleState( |
|
|
|
|
prev_token = self.start_token, |
|
|
|
|
prng_key = prng_key, |
|
|
|
|
keys_state = jnp.zeros(state_shape, dtype=jnp.float16), |
|
|
|
|
values_state = jnp.zeros(state_shape, dtype=jnp.float16) |
|
|
|
|
keys_state = jnp.zeros(state_shape), |
|
|
|
|
values_state = jnp.zeros(state_shape) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
_, image_tokens = lax.scan( |
|
|
|
|