From d99828a2397761930eae763d9fa8df71b438c335 Mon Sep 17 00:00:00 2001
From: Brett Kuprel
Date: Wed, 29 Jun 2022 14:56:28 -0400
Subject: [PATCH] simplified flax attention and matched torch attention

---
 README.md                                    |  2 +-
 min_dalle/models/dalle_bart_decoder_flax.py  | 97 ++++++++------
 min_dalle/models/dalle_bart_decoder_torch.py | 48 +++++-----
 min_dalle/models/dalle_bart_encoder_flax.py  | 12 ++-
 4 files changed, 72 insertions(+), 87 deletions(-)

diff --git a/README.md b/README.md
index 0f84204..d735762 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@

 This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).

-DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
+It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab

 ### Setup

diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py
index fa2d457..704201d 100644
--- a/min_dalle/models/dalle_bart_decoder_flax.py
+++ b/min_dalle/models/dalle_bart_decoder_flax.py
@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)


@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
             state_index
         )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
+
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
         )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state


 class DalleBartDecoderLayerFlax(nn.Module):
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state

-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray


 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]

@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
             DalleBartDecoderLayerFlax,
             variable_axes = { "params": 0, "cache": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length=self.layer_count,
         )(
             self.image_token_count,
@@ -178,28 +167,26 @@
     def __call__(
         self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state

     def sample_image_tokens(
         self,
@@ -213,12 +200,11 @@
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -233,26 +219,23 @@
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token

         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )

         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )

         _, image_tokens = lax.scan(
diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py
index 9957f2b..3aab6dd 100644
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder_torch.py
@@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
     def forward(
         self,
         decoder_state: FloatTensor,
-        keys_values: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        batch_count = decoder_state.shape[0]
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
-        keys_values = torch.where(
+        attention_state = torch.where(
             token_mask[None, :, None],
             torch.cat([keys, values]),
-            keys_values
+            attention_state
         )
-        keys, values = keys_values[:batch_count], keys_values[batch_count:]
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
-        return decoder_state, keys_values
+        return decoder_state, attention_state


 class DecoderLayerTorch(nn.Module):
@@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
         self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
@@ -77,9 +77,9 @@
         self_attn_mask = self.token_indices < token_index + 1
         token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
-        decoder_state, keys_values_state = self.self_attn.forward(
+        decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
-            keys_values_state,
+            attention_state,
             self_attn_mask,
             token_mask
         )
@@ -102,7 +102,7 @@
         decoder_state = self.glu.forward(decoder_state)
         decoder_state = residual + decoder_state

-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 class DalleBartDecoderTorch(nn.Module):
@@ -139,8 +139,9 @@
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
-        self.keys_values_state_shape = (
-            layer_count * 2 * batch_count,
+        self.attention_state_shape = (
+            layer_count,
+            2 * batch_count,
             image_token_count,
             embed_count
         )
@@ -157,7 +158,7 @@
         self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
@@ -168,17 +169,16 @@
         decoder_state += self.embed_positions.forward(token_index)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
-        keys_values = []
-        for i, layer in enumerate(self.layers):
-            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
-            decoder_state, keys_values_layer = layer.forward(
+        attention_states_new = []
+        for i in range(self.layer_count):
+            decoder_state, attention_state_layer = self.layers[i].forward(
                 decoder_state,
                 encoder_state,
-                keys_values_state[j1:j2],
+                attention_state[i],
                 attention_mask,
                 token_index[:1]
             )
-            keys_values.append(keys_values_layer)
+            attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
@@ -190,7 +190,7 @@
             self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, torch.cat(keys_values)
+        return probs, torch.stack(attention_states_new)


     def forward(
@@ -199,17 +199,17 @@
         self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
-        keys_values_state = torch.zeros(self.keys_values_state_shape)
+        attention_state = torch.zeros(self.attention_state_shape)
         if torch.cuda.is_available():
-            keys_values_state = keys_values_state.cuda()
+            attention_state = attention_state.cuda()
         image_token = self.start_token

         for i in range(self.sample_token_count):
             token_index = self.token_indices[i:i+1]
-            probs, keys_values_state = self.decode_step(
+            probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
-                keys_values_state = keys_values_state,
+                attention_state = attention_state,
                 prev_token_and_index = torch.cat([image_token, token_index])
             )
diff --git a/min_dalle/models/dalle_bart_encoder_flax.py b/min_dalle/models/dalle_bart_encoder_flax.py
index 3d159f0..d320e69 100644
--- a/min_dalle/models/dalle_bart_encoder_flax.py
+++ b/min_dalle/models/dalle_bart_encoder_flax.py
@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
         queries: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias: jnp.ndarray = lax.select(
             attention_mask,
             jnp.full(attention_mask.shape, 0.0),
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj(encoder_state).reshape(shape_split)
-        values = self.v_proj(encoder_state).reshape(shape_split)
-        queries = self.q_proj(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(encoder_state)
         return self.forward(keys, values, queries, attention_mask)
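
--

A minimal standalone sketch of the packed cache layout this patch introduces: per decoder layer, keys and values are concatenated along the batch axis into a single attention_state tensor, written at the current token position, then sliced back apart before attending. The shape mirrors attention_state_shape in DalleBartDecoderTorch, but the sizes below are made-up examples, and plain indexing stands in for the patch's torch.where / lax.dynamic_update_slice updates.

import torch

layer_count, batch_count = 2, 2
image_token_count, embed_count = 256, 16

# one cache tensor for all layers; per layer, rows 0..batch_count-1 hold
# keys and rows batch_count..2*batch_count-1 hold values
attention_state = torch.zeros(
    layer_count, 2 * batch_count, image_token_count, embed_count
)

# pretend single-token projections for one decode step at token_index
token_index = 0
keys = torch.randn(batch_count, 1, embed_count)
values = torch.randn(batch_count, 1, embed_count)

# write this step's keys and values into layer 0 of the cache
attention_state[0, :, token_index : token_index + 1] = torch.cat([keys, values])

# read them back apart, as DecoderSelfAttentionTorch.forward does
keys_cached = attention_state[0, :batch_count]
values_cached = attention_state[0, batch_count:]
assert torch.equal(keys_cached[:, :1], keys)
assert torch.equal(values_cached[:, :1], values)

Packing keys and values into one tensor halves the number of state buffers threaded through the layers, which is what lets the flax scanned layer return a single stacked attention_state (out_axes = 0) and the torch loop return torch.stack(attention_states_new).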