previous commit broke colab example, so adjusting flax requirement to 0.4.2 for now

This commit is contained in:
Brett Kuprel 2022-06-28 08:04:08 -04:00
parent 38ebe54a38
commit 34df2b97df
2 changed files with 5 additions and 6 deletions

View File

@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
shape_split = decoder_state.shape[:2] + (self.head_count, -1) shape_split = decoder_state.shape[:2] + (self.head_count, -1)
keys_state = lax.dynamic_update_slice( keys_state = lax.dynamic_update_slice(
keys_state, keys_state,
self.k_proj(decoder_state).reshape(shape_split), self.k_proj(decoder_state).reshape(shape_split),
state_index state_index
) )
@ -205,7 +205,6 @@ class DalleBartDecoderFlax(nn.Module):
params: dict params: dict
) -> jnp.ndarray: ) -> jnp.ndarray:
attention_mask = jnp.not_equal(text_tokens, 1) attention_mask = jnp.not_equal(text_tokens, 1)
encoder_state = encoder_state.astype(jnp.float16)
def sample_next_image_token( def sample_next_image_token(
state: SampleState, state: SampleState,
@ -248,8 +247,8 @@ class DalleBartDecoderFlax(nn.Module):
initial_state = SampleState( initial_state = SampleState(
prev_token = self.start_token, prev_token = self.start_token,
prng_key = prng_key, prng_key = prng_key,
keys_state = jnp.zeros(state_shape, dtype=jnp.float16), keys_state = jnp.zeros(state_shape),
values_state = jnp.zeros(state_shape, dtype=jnp.float16) values_state = jnp.zeros(state_shape)
) )
_, image_tokens = lax.scan( _, image_tokens = lax.scan(

View File

@ -1,2 +1,2 @@
torch torch==0.4.2
flax flax