previous commit broke colab example, so adjusting flax requirement to 0.4.2 for now
This commit is contained in:
parent
38ebe54a38
commit
34df2b97df
|
@ -35,7 +35,7 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
||||||
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
|
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
|
||||||
keys_state = lax.dynamic_update_slice(
|
keys_state = lax.dynamic_update_slice(
|
||||||
keys_state,
|
keys_state,
|
||||||
self.k_proj(decoder_state).reshape(shape_split),
|
self.k_proj(decoder_state).reshape(shape_split),
|
||||||
state_index
|
state_index
|
||||||
)
|
)
|
||||||
|
@ -205,7 +205,6 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
params: dict
|
params: dict
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
attention_mask = jnp.not_equal(text_tokens, 1)
|
attention_mask = jnp.not_equal(text_tokens, 1)
|
||||||
encoder_state = encoder_state.astype(jnp.float16)
|
|
||||||
|
|
||||||
def sample_next_image_token(
|
def sample_next_image_token(
|
||||||
state: SampleState,
|
state: SampleState,
|
||||||
|
@ -248,8 +247,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
initial_state = SampleState(
|
initial_state = SampleState(
|
||||||
prev_token = self.start_token,
|
prev_token = self.start_token,
|
||||||
prng_key = prng_key,
|
prng_key = prng_key,
|
||||||
keys_state = jnp.zeros(state_shape, dtype=jnp.float16),
|
keys_state = jnp.zeros(state_shape),
|
||||||
values_state = jnp.zeros(state_shape, dtype=jnp.float16)
|
values_state = jnp.zeros(state_shape)
|
||||||
)
|
)
|
||||||
|
|
||||||
_, image_tokens = lax.scan(
|
_, image_tokens = lax.scan(
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
torch
|
torch==0.4.2
|
||||||
flax
|
flax
|
||||||
|
|
Loading…
Reference in New Issue
Block a user