simplified flax attention and matched torch attention

Brett Kuprel 2022-06-29 14:56:28 -04:00
parent 61cc99c13c
commit d99828a239
4 changed files with 72 additions and 87 deletions

README.md

@@ -4,7 +4,7 @@
 This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
-DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
+It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
 ### Setup


@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)
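The simplification works because the per-head reshape and query scaling that cross-attention used to do inline now happen once in the shared `AttentionFlax.forward` (see the hunk at the end of this diff). A minimal shape sketch with illustrative sizes, assuming `embed_count = head_count * head_dim`:

```python
# Illustrative shapes only: the k/v/q projections return flat
# (batch, seq, embed) arrays, and the shared forward does the head split.
import jax.numpy as jnp

batch, seq, head_count, head_dim = 2, 16, 4, 8
flat = jnp.zeros((batch, seq, head_count * head_dim))   # projection output
split = flat.reshape(flat.shape[:2] + (head_count, -1)) # done in forward now
assert split.shape == (batch, seq, head_count, head_dim)
```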
@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
             state_index
         )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
        )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state

 class DalleBartDecoderLayerFlax(nn.Module):
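The key change above: instead of two cache tensors updated with two `lax.dynamic_update_slice` calls, the current token's keys and values are concatenated along the batch axis and written into a single buffer. A self-contained sketch with illustrative shapes (not the repo's API):

```python
import jax.numpy as jnp
from jax import lax

batch, token_count, embed = 2, 8, 16
# one cache per layer: keys in rows [:batch], values in rows [batch:]
attention_state = jnp.zeros((2 * batch, token_count, embed))

token_index = 3
keys = jnp.ones((batch, 1, embed))    # k_proj output for the current token
values = jnp.ones((batch, 1, embed))  # v_proj output for the current token

attention_state = lax.dynamic_update_slice(
    attention_state,
    jnp.concatenate([keys, values]),  # (2 * batch, 1, embed)
    (0, token_index, 0)               # write at the current position
)
keys_cache = attention_state[:batch]
values_cache = attention_state[batch:]
assert keys_cache.shape == (batch, token_count, embed)
```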
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state

 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray

 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
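`super_conditioned` (unchanged context here) blends the two logit rows as `a * x + (1 - a) * y = y + a * (x - y)`, so a factor `a > 1` extrapolates past row 0 and away from row 1, the usual super-conditioning / classifier-free-guidance trick. A quick numeric check (reading row 0 as the conditioned branch is my assumption, not stated in the diff):

```python
import jax.numpy as jnp

def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
    return a * logits[0, -1] + (1 - a) * logits[1, -1]

row0 = jnp.array([[2.0, 0.0]])            # (seq=1, vocab=2)
row1 = jnp.array([[1.0, 0.0]])
logits = jnp.stack([row0, row1])          # (2, seq, vocab)
# 3 * 2 + (1 - 3) * 1 = 4: pushed beyond row 0, away from row 1
assert jnp.allclose(super_conditioned(logits, 3.0), jnp.array([4.0, 0.0]))
```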
@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
             DalleBartDecoderLayerFlax,
             variable_axes = { "params": 0, "cache": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length=self.layer_count,
         )(
             self.image_token_count,
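With one state tensor instead of two, `in_axes` above drops an axis spec and `out_axes` collapses from a tuple to a single `0`: each scanned layer receives and emits its own slice of the stacked `(layer_count, ...)` cache. A minimal sketch of this `nn.scan` pattern (names and sizes illustrative, not the repo's):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Layer(nn.Module):
    @nn.compact
    def __call__(self, carry, context, state):
        # carry flows through all layers; state is this layer's cache slice
        carry = nn.Dense(carry.shape[-1])(carry) + context
        return carry, state + 1.0

ScannedLayers = nn.scan(
    Layer,
    variable_axes = { "params": 0 },   # separate params per layer
    split_rngs = { "params": True },
    in_axes = (nn.broadcast, 0),       # context broadcast, state split on axis 0
    out_axes = 0,                      # per-layer states restacked on axis 0
    length = 4,
)

x = jnp.ones((2, 8))
states = jnp.zeros((4, 2, 8))          # (layer_count, ...) stacked cache
variables = ScannedLayers().init(jax.random.PRNGKey(0), x, x, states)
y, new_states = ScannedLayers().apply(variables, x, x, states)
assert new_states.shape == states.shape
```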
@@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
     def __call__(
         self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state

     def sample_image_tokens(
         self,
@@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token

         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )
         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )
         _, image_tokens = lax.scan(
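The sampler threads `SampleState` through `lax.scan` over token positions; only the carry's field names changed in this commit. A stripped-down sketch of the pattern (the step function is a stand-in, not the repo's decoder call):

```python
import jax.numpy as jnp
from jax import lax

def sample_next_image_token(state, token_index):
    prev_token, attention_state = state
    # a real step would run the decoder here and sample from its logits
    next_token = prev_token + 1
    return (next_token, attention_state), next_token

initial_state = (jnp.array(0, dtype=jnp.int32), jnp.zeros((4, 8)))
_, image_tokens = lax.scan(
    sample_next_image_token,
    initial_state,
    jnp.arange(16)          # one scan step per image token position
)
assert image_tokens.shape == (16,)
```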


@@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
     def forward(
         self,
         decoder_state: FloatTensor,
-        keys_values: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        batch_count = decoder_state.shape[0]
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
-        keys_values = torch.where(
+        attention_state = torch.where(
             token_mask[None, :, None],
             torch.cat([keys, values]),
-            keys_values
+            attention_state
         )
-        keys, values = keys_values[:batch_count], keys_values[batch_count:]
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
-        return decoder_state, keys_values
+        return decoder_state, attention_state

 class DecoderLayerTorch(nn.Module):
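The torch side mirrors the flax change: `torch.where` with a one-hot position mask plays the role of `lax.dynamic_update_slice`, writing the current step's keys and values into the single cache tensor. A self-contained sketch with illustrative shapes:

```python
import torch

batch, token_count, embed = 2, 8, 16
attention_state = torch.zeros(2 * batch, token_count, embed)

token_index = 3
token_mask = torch.arange(token_count) == token_index  # True only at position 3
keys = torch.ones(batch, 1, embed)
values = torch.ones(batch, 1, embed)

attention_state = torch.where(
    token_mask[None, :, None],       # broadcast over batch and embed dims
    torch.cat([keys, values]),       # (2 * batch, 1, embed), broadcast over seq
    attention_state                  # untouched everywhere the mask is False
)
keys_cache, values_cache = attention_state[:batch], attention_state[batch:]
assert bool((keys_cache[:, token_index] == 1).all())
```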
@@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
         self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
@@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
         self_attn_mask = self.token_indices < token_index + 1
         token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
-        decoder_state, keys_values_state = self.self_attn.forward(
+        decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
-            keys_values_state,
+            attention_state,
             self_attn_mask,
             token_mask
         )
@@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
         decoder_state = self.glu.forward(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state

 class DalleBartDecoderTorch(nn.Module):
@@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
-        self.keys_values_state_shape = (
-            layer_count * 2 * batch_count,
+        self.attention_state_shape = (
+            layer_count,
+            2 * batch_count,
             image_token_count,
             embed_count
         )
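Giving the cache a leading layer axis, `(layer_count, 2 * batch, token_count, embed)`, replaces the old flat layout and its offset arithmetic with plain indexing, and matches the flax `attention_state_shape` above. A small illustration of the two layouts:

```python
import torch

layer_count, batch_count, token_count, embed = 3, 2, 8, 16

# old: one flat tensor, per-layer slices computed by hand
flat = torch.zeros(layer_count * 2 * batch_count, token_count, embed)
i = 1
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
old_slice = flat[j1:j2]

# new: a leading layer axis, indexed directly; torch.stack rebuilds it,
# where the old code needed torch.cat to restore the flat layout
stacked = torch.zeros(layer_count, 2 * batch_count, token_count, embed)
new_slice = stacked[i]
assert old_slice.shape == new_slice.shape
assert torch.stack([new_slice] * layer_count).shape == stacked.shape
```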
@@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
         self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
@@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
         decoder_state += self.embed_positions.forward(token_index)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
-        keys_values = []
-        for i, layer in enumerate(self.layers):
-            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
-            decoder_state, keys_values_layer = layer.forward(
+        attention_states_new = []
+        for i in range(self.layer_count):
+            decoder_state, attention_state_layer = self.layers[i].forward(
                 decoder_state,
                 encoder_state,
-                keys_values_state[j1:j2],
+                attention_state[i],
                 attention_mask,
                 token_index[:1]
             )
-            keys_values.append(keys_values_layer)
+            attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
@@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
             self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, torch.cat(keys_values)
+        return probs, torch.stack(attention_states_new)

     def forward(
@@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
         encoder_state: FloatTensor
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
-        keys_values_state = torch.zeros(self.keys_values_state_shape)
+        attention_state = torch.zeros(self.attention_state_shape)
         if torch.cuda.is_available():
-            keys_values_state = keys_values_state.cuda()
+            attention_state = attention_state.cuda()
         image_token = self.start_token
         for i in range(self.sample_token_count):
             token_index = self.token_indices[i:i+1]
-            probs, keys_values_state = self.decode_step(
+            probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
-                keys_values_state = keys_values_state,
+                attention_state = attention_state,
                 prev_token_and_index = torch.cat([image_token, token_index])
             )


@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
         queries: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias: jnp.ndarray = lax.select(
             attention_mask,
             jnp.full(attention_mask.shape, 0.0),
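These four added lines are where the head split and query scaling now live for every attention variant. A rough sketch of such a shared forward in einsum form (illustrative only; it omits the attention bias that the `lax.select` above feeds into the softmax):

```python
import jax
import jax.numpy as jnp

def attention_forward(keys, values, queries, head_count):
    # split flat (batch, seq, embed) projections into heads
    keys = keys.reshape(keys.shape[:2] + (head_count, -1))
    values = values.reshape(values.shape[:2] + (head_count, -1))
    queries = queries.reshape(queries.shape[:2] + (head_count, -1))
    queries /= queries.shape[-1] ** 0.5          # scale by sqrt(head_dim)
    weights = jnp.einsum("bqhd,bkhd->bhqk", queries, keys)
    weights = jax.nn.softmax(weights, axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, values)
    return out.reshape(out.shape[:2] + (-1,))    # merge heads back

out = attention_forward(
    jnp.ones((2, 8, 16)), jnp.ones((2, 8, 16)), jnp.ones((2, 8, 16)), 4
)
assert out.shape == (2, 8, 16)
```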
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj(encoder_state).reshape(shape_split)
-        values = self.v_proj(encoder_state).reshape(shape_split)
-        queries = self.q_proj(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(encoder_state)
         return self.forward(keys, values, queries, attention_mask)