simplified flax attention and matched torch attention

Branch: main
Brett Kuprel committed 2 years ago
commit d99828a239 · parent 61cc99c13c
Changed files:
  1. README.md (2 lines changed)
  2. min_dalle/models/dalle_bart_decoder_flax.py (97 lines changed)
  3. min_dalle/models/dalle_bart_decoder_torch.py (48 lines changed)
  4. min_dalle/models/dalle_bart_encoder_flax.py (12 lines changed)

README.md

@@ -4,7 +4,7 @@
 This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
 
-DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
+It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
 
 ### Setup

min_dalle/models/dalle_bart_decoder_flax.py

@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)
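Note: the projections now return full-width tensors; the per-head reshape and the query scaling move into the shared AttentionFlax.forward (see the dalle_bart_encoder_flax.py hunk below). A minimal sketch of that split, with illustrative sizes:

```python
import jax.numpy as jnp

# Illustrative sizes: batch=2, tokens=4, head_count=2, embed=8.
head_count = 2
x = jnp.ones((2, 4, 8))                         # full-width output of a k/v/q projection
x = x.reshape(x.shape[:2] + (head_count, -1))   # (2, 4, 2, 4): split embed into heads
queries = x / x.shape[-1] ** 0.5                # queries are also scaled by sqrt(head_dim)
```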
@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
             state_index
         )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
         )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state


 class DalleBartDecoderLayerFlax(nn.Module):
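Note: in the hunk above, the two per-layer caches collapse into one attention_state of shape (batch * 2, token_count, embed). Keys are stacked above values along the leading axis, written in place at the current token with lax.dynamic_update_slice, then split back out by batch_count. A self-contained sketch with illustrative sizes:

```python
import jax.numpy as jnp
from jax import lax

# Illustrative sizes: batch=2, token_count=16, embed=8.
batch_count, token_count, embed = 2, 16, 8
attention_state = jnp.zeros((batch_count * 2, token_count, embed))
keys = jnp.ones((batch_count, 1, embed))       # projections for the current token
values = jnp.ones((batch_count, 1, embed))
token_index = 3
attention_state = lax.dynamic_update_slice(
    attention_state,
    jnp.concatenate([keys, values]),           # (batch_count * 2, 1, embed)
    (0, token_index, 0)                        # state_index: write at this token slot
)
keys, values = attention_state[:batch_count], attention_state[batch_count:]
```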
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray


 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
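Note on the unchanged super_conditioned helper shown as context above: the decoder runs two streams per image (rows 0 and 1 of the batch), and the last-position logits are mixed classifier-free-guidance style, since a·L0 + (1 − a)·L1 = L1 + a·(L0 − L1). A small worked example:

```python
import jax.numpy as jnp

# Two streams, one position, a two-token vocabulary.
logits = jnp.array([[[1.0, 2.0]],
                    [[0.5, 1.0]]])
a = 10.0
mixed = a * logits[0, -1] + (1 - a) * logits[1, -1]   # -> [5.5, 11.0]
# Equivalently logits[1, -1] + a * (logits[0, -1] - logits[1, -1]):
# with a > 1 the mix extrapolates past row 1 in the direction of row 0.
```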
@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
             DalleBartDecoderLayerFlax,
             variable_axes = { "params": 0, "cache": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length=self.layer_count,
         )(
             self.image_token_count,
@@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
     def __call__(
         self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state

     def sample_image_tokens(
         self,
@@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token

         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )
         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )
         _, image_tokens = lax.scan(
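Note: the sampling loop itself is a lax.scan over token positions, threading SampleState (now just prev_token, prng_key, attention_state) through sample_next_image_token. A stubbed sketch of the control flow; the random draw stands in for the real logits sampling:

```python
import jax
import jax.numpy as jnp
from jax import lax

def sample_next_image_token(state, token_index):
    # Stub: the real step applies the decoder and samples from its logits.
    prng_key, subkey = jax.random.split(state["prng_key"])
    next_token = jax.random.randint(subkey, (1,), 0, 16384)
    return {"prev_token": next_token, "prng_key": prng_key}, next_token

initial_state = {
    "prev_token": jnp.zeros((1,), jnp.int32),
    "prng_key": jax.random.PRNGKey(0),
}
_, image_tokens = lax.scan(
    sample_next_image_token,
    initial_state,
    jnp.arange(256)        # one scan step per image token position
)
```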

min_dalle/models/dalle_bart_decoder_torch.py

@@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
     def forward(
         self,
         decoder_state: FloatTensor,
-        keys_values: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        batch_count = decoder_state.shape[0]
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
-        keys_values = torch.where(
+        attention_state = torch.where(
             token_mask[None, :, None],
             torch.cat([keys, values]),
-            keys_values
+            attention_state
         )
-        keys, values = keys_values[:batch_count], keys_values[batch_count:]
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
-        return decoder_state, keys_values
+        return decoder_state, attention_state


 class DecoderLayerTorch(nn.Module):
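Note: DecoderSelfAttentionTorch.forward above mirrors the Flax cache update, but merges the current token's keys/values into attention_state with a broadcast torch.where on token_mask rather than a dynamic slice. A self-contained sketch with illustrative sizes:

```python
import torch

# Illustrative sizes: batch=2, token_count=16, embed=8.
batch_count, token_count, embed = 2, 16, 8
attention_state = torch.zeros(2 * batch_count, token_count, embed)
keys = torch.ones(batch_count, 1, embed)     # projections for the current token
values = torch.ones(batch_count, 1, embed)
token_mask = torch.arange(token_count) == 3  # True only at the current token index
attention_state = torch.where(
    token_mask[None, :, None],               # broadcasts over batch and embed axes
    torch.cat([keys, values]),               # (2 * batch_count, 1, embed)
    attention_state
)
keys, values = attention_state[:batch_count], attention_state[batch_count:]
```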
@@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
         self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
@@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
         self_attn_mask = self.token_indices < token_index + 1
         token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
-        decoder_state, keys_values_state = self.self_attn.forward(
+        decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
-            keys_values_state,
+            attention_state,
             self_attn_mask,
             token_mask
         )
@@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
         decoder_state = self.glu.forward(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 class DalleBartDecoderTorch(nn.Module):
@@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
-        self.keys_values_state_shape = (
-            layer_count * 2 * batch_count,
+        self.attention_state_shape = (
+            layer_count,
+            2 * batch_count,
             image_token_count,
             embed_count
         )
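Note: with the extra leading layer axis, decode_step can hand attention_state[i] straight to layer i instead of slicing a flattened (layer_count * 2 * batch_count, ...) tensor, and the Flax and Torch caches now share the same per-layer layout. A small sketch:

```python
import torch

# Illustrative sizes; the real values come from the model config.
layer_count, batch_count, image_token_count, embed_count = 12, 2, 256, 1024
attention_state = torch.zeros(
    layer_count, 2 * batch_count, image_token_count, embed_count
)
layer_state = attention_state[0]   # (2 * batch_count, tokens, embed) for layer 0
```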
@@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
         self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
@@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
         decoder_state += self.embed_positions.forward(token_index)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
-        keys_values = []
-        for i, layer in enumerate(self.layers):
-            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
-            decoder_state, keys_values_layer = layer.forward(
+        attention_states_new = []
+        for i in range(self.layer_count):
+            decoder_state, attention_state_layer = self.layers[i].forward(
                 decoder_state,
                 encoder_state,
-                keys_values_state[j1:j2],
+                attention_state[i],
                 attention_mask,
                 token_index[:1]
             )
-            keys_values.append(keys_values_layer)
+            attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
@@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
             self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, torch.cat(keys_values)
+        return probs, torch.stack(attention_states_new)

     def forward(
@@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
         encoder_state: FloatTensor
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
-        keys_values_state = torch.zeros(self.keys_values_state_shape)
+        attention_state = torch.zeros(self.attention_state_shape)
         if torch.cuda.is_available():
-            keys_values_state = keys_values_state.cuda()
+            attention_state = attention_state.cuda()
         image_token = self.start_token
         for i in range(self.sample_token_count):
             token_index = self.token_indices[i:i+1]
-            probs, keys_values_state = self.decode_step(
+            probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
-                keys_values_state = keys_values_state,
+                attention_state = attention_state,
                 prev_token_and_index = torch.cat([image_token, token_index])
             )
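Note: decode_step returns the mixed next-token probabilities together with the refreshed cache, and the loop feeds both back in each iteration. A stubbed sketch of that loop; the multinomial draw is an assumption standing in for the repo's actual sampling of image_token from probs:

```python
import torch

def decode_step_stub(attention_state, prev_token_and_index):
    # Stand-in for the model: uniform probabilities over a 16384-token vocab.
    return torch.full((16384,), 1.0 / 16384), attention_state

attention_state = torch.zeros(12, 4, 256, 1024)
image_token = torch.zeros(1, dtype=torch.long)     # start token stand-in
token_indices = torch.arange(256)
image_tokens = []
for i in range(4):                                 # sample_token_count iterations
    token_index = token_indices[i:i + 1]
    probs, attention_state = decode_step_stub(
        attention_state,
        torch.cat([image_token, token_index])
    )
    image_token = torch.multinomial(probs, 1)      # assumed draw of the next token
    image_tokens.append(image_token)
```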

min_dalle/models/dalle_bart_encoder_flax.py

@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
         queries: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias: jnp.ndarray = lax.select(
             attention_mask,
             jnp.full(attention_mask.shape, 0.0),
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj(encoder_state).reshape(shape_split)
-        values = self.v_proj(encoder_state).reshape(shape_split)
-        queries = self.q_proj(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(encoder_state)
         return self.forward(keys, values, queries, attention_mask)
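Note: pulling the head split and query scaling into the base class means every attention variant now shares one forward. A hedged sketch of what such a forward looks like after this change; the einsum layout and the -inf mask fill are assumptions, not necessarily the repo's exact code:

```python
import jax
import jax.numpy as jnp
from jax import lax

def attention_forward(keys, values, queries, attention_mask, head_count):
    # Raw projections come in; split heads and scale queries here, once.
    keys = keys.reshape(keys.shape[:2] + (head_count, -1))
    values = values.reshape(values.shape[:2] + (head_count, -1))
    queries = queries.reshape(queries.shape[:2] + (head_count, -1))
    queries /= queries.shape[-1] ** 0.5
    attention_bias = lax.select(
        attention_mask,
        jnp.full(attention_mask.shape, 0.0),
        jnp.full(attention_mask.shape, -jnp.inf)
    )
    weights = jnp.einsum("bqhd,bkhd->bhqk", queries, keys)
    weights += attention_bias[:, None, None, :]      # mask keys per batch row
    weights = jax.nn.softmax(weights)
    hidden = jnp.einsum("bhqk,bkhd->bqhd", weights, values)
    return hidden.reshape(hidden.shape[:2] + (-1,))  # merge heads back

out = attention_forward(
    jnp.ones((2, 4, 8)), jnp.ones((2, 4, 8)), jnp.ones((2, 4, 8)),
    jnp.ones((2, 4), dtype=bool), head_count=2
)
```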
