simplified flax attention and matched torch attention

This commit is contained in:
Brett Kuprel 2022-06-29 14:56:28 -04:00
parent 61cc99c13c
commit d99828a239
4 changed files with 72 additions and 87 deletions

2
README.md vendored
View File

@@ -4,7 +4,7 @@
 This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
-DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
+It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
 ### Setup

View File

@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)
@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
+            state_index
+        )
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
         )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state
 class DalleBartDecoderLayerFlax(nn.Module):
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state
 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray
 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
             DalleBartDecoderLayerFlax,
             variable_axes = { "params": 0, "cache": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length=self.layer_count,
         )(
             self.image_token_count,
@@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
     def __call__(
         self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state
     def sample_image_tokens(
         self,
@@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token
         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )
         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )
         _, image_tokens = lax.scan(

View File

@@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
     def forward(
         self,
         decoder_state: FloatTensor,
-        keys_values: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        batch_count = decoder_state.shape[0]
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
-        keys_values = torch.where(
+        attention_state = torch.where(
             token_mask[None, :, None],
             torch.cat([keys, values]),
-            keys_values
+            attention_state
         )
-        keys, values = keys_values[:batch_count], keys_values[batch_count:]
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
-        return decoder_state, keys_values
+        return decoder_state, attention_state
 class DecoderLayerTorch(nn.Module):
@@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
         self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
@@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
         self_attn_mask = self.token_indices < token_index + 1
         token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
-        decoder_state, keys_values_state = self.self_attn.forward(
+        decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
-            keys_values_state,
+            attention_state,
             self_attn_mask,
             token_mask
         )
@@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
         decoder_state = self.glu.forward(decoder_state)
         decoder_state = residual + decoder_state
-        return decoder_state, keys_values_state
+        return decoder_state, attention_state
 class DalleBartDecoderTorch(nn.Module):
@@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
-        self.keys_values_state_shape = (
-            layer_count * 2 * batch_count,
+        self.attention_state_shape = (
+            layer_count,
+            2 * batch_count,
             image_token_count,
             embed_count
         )
@@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
         self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
@@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
         decoder_state += self.embed_positions.forward(token_index)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
-        keys_values = []
-        for i, layer in enumerate(self.layers):
-            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
-            decoder_state, keys_values_layer = layer.forward(
+        attention_states_new = []
+        for i in range(self.layer_count):
+            decoder_state, attention_state_layer = self.layers[i].forward(
                 decoder_state,
                 encoder_state,
-                keys_values_state[j1:j2],
+                attention_state[i],
                 attention_mask,
                 token_index[:1]
             )
-            keys_values.append(keys_values_layer)
+            attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
@@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
             self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, torch.cat(keys_values)
+        return probs, torch.stack(attention_states_new)
     def forward(
@@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
         encoder_state: FloatTensor
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
-        keys_values_state = torch.zeros(self.keys_values_state_shape)
+        attention_state = torch.zeros(self.attention_state_shape)
         if torch.cuda.is_available():
-            keys_values_state = keys_values_state.cuda()
+            attention_state = attention_state.cuda()
         image_token = self.start_token
         for i in range(self.sample_token_count):
             token_index = self.token_indices[i:i+1]
-            probs, keys_values_state = self.decode_step(
+            probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
-                keys_values_state = keys_values_state,
+                attention_state = attention_state,
                 prev_token_and_index = torch.cat([image_token, token_index])
             )

View File

@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
         queries: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias: jnp.ndarray = lax.select(
             attention_mask,
             jnp.full(attention_mask.shape, 0.0),
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj(encoder_state).reshape(shape_split)
-        values = self.v_proj(encoder_state).reshape(shape_split)
-        queries = self.q_proj(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(encoder_state)
         return self.forward(keys, values, queries, attention_mask)