simplified flax attention and matched torch attention
This commit is contained in:
parent
61cc99c13c
commit
d99828a239
2
README.md
vendored
2
README.md
vendored
|
@ -4,7 +4,7 @@
|
|||
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
|
||||
|
||||
DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
|
||||
It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
|
||||
|
||||
### Setup
|
||||
|
||||
|
|
|
@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
|||
encoder_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
keys: jnp.ndarray = self.k_proj(encoder_state)
|
||||
values: jnp.ndarray = self.v_proj(encoder_state)
|
||||
queries: jnp.ndarray = self.q_proj(decoder_state)
|
||||
query_shape = queries.shape[:2] + (self.head_count, -1)
|
||||
key_value_shape = keys.shape[:2] + (self.head_count, -1)
|
||||
keys = keys.reshape(key_value_shape)
|
||||
values = values.reshape(key_value_shape)
|
||||
queries = queries.reshape(query_shape)
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
keys = self.k_proj(encoder_state)
|
||||
values = self.v_proj(encoder_state)
|
||||
queries = self.q_proj(decoder_state)
|
||||
return self.forward(keys, values, queries, attention_mask)
|
||||
|
||||
|
||||
|
@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
|||
def __call__(
|
||||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
attention_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray,
|
||||
state_index: tuple
|
||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
||||
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
|
||||
keys_state = lax.dynamic_update_slice(
|
||||
keys_state,
|
||||
self.k_proj(decoder_state).reshape(shape_split),
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
keys = self.k_proj(decoder_state)
|
||||
values = self.v_proj(decoder_state)
|
||||
queries = self.q_proj(decoder_state)
|
||||
|
||||
attention_state = lax.dynamic_update_slice(
|
||||
attention_state,
|
||||
jnp.concatenate([keys, values]),
|
||||
state_index
|
||||
)
|
||||
values_state = lax.dynamic_update_slice(
|
||||
values_state,
|
||||
self.v_proj(decoder_state).reshape(shape_split),
|
||||
state_index
|
||||
)
|
||||
queries = self.q_proj(decoder_state).reshape(shape_split)
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
batch_count = decoder_state.shape[0]
|
||||
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
||||
|
||||
decoder_state = self.forward(
|
||||
keys_state,
|
||||
values_state,
|
||||
keys,
|
||||
values,
|
||||
queries,
|
||||
attention_mask
|
||||
)
|
||||
return decoder_state, (keys_state, values_state)
|
||||
return decoder_state, attention_state
|
||||
|
||||
|
||||
class DalleBartDecoderLayerFlax(nn.Module):
|
||||
|
@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
|||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
attention_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray,
|
||||
token_index: int
|
||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
# Self Attention
|
||||
residual = decoder_state
|
||||
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
||||
|
@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
|||
jnp.arange(self.image_token_count) < token_index + 1,
|
||||
(decoder_state.shape[0], 1)
|
||||
)
|
||||
decoder_state, keys_values_state = self.self_attn(
|
||||
decoder_state, attention_state = self.self_attn(
|
||||
decoder_state,
|
||||
keys_state,
|
||||
values_state,
|
||||
attention_state,
|
||||
self_attention_mask,
|
||||
(0, token_index, 0, 0)
|
||||
(0, token_index, 0)
|
||||
)
|
||||
decoder_state = self.self_attn_layer_norm(decoder_state)
|
||||
decoder_state = residual + decoder_state
|
||||
|
@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
|||
decoder_state = self.glu(decoder_state)
|
||||
decoder_state = residual + decoder_state
|
||||
|
||||
return decoder_state, keys_values_state
|
||||
return decoder_state, attention_state
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class SampleState:
|
||||
prev_token: jnp.ndarray
|
||||
prng_key: jnp.ndarray
|
||||
keys_state: jnp.ndarray
|
||||
values_state: jnp.ndarray
|
||||
attention_state: jnp.ndarray
|
||||
|
||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||
|
@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
DalleBartDecoderLayerFlax,
|
||||
variable_axes = { "params": 0, "cache": 0 },
|
||||
split_rngs = { "params": True },
|
||||
in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
|
||||
out_axes = (0, 0),
|
||||
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
||||
out_axes = 0,
|
||||
length=self.layer_count,
|
||||
)(
|
||||
self.image_token_count,
|
||||
|
@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
def __call__(
|
||||
self,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
attention_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray,
|
||||
prev_token: int,
|
||||
token_index: int
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
batch_count = encoder_state.shape[0]
|
||||
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
||||
decoder_state = self.embed_tokens(prev_token * ones)
|
||||
decoder_state += self.embed_positions(token_index * ones)
|
||||
decoder_state = self.layernorm_embedding(decoder_state)
|
||||
decoder_state, (keys_state, values_state) = self.layers(
|
||||
decoder_state, attention_state = self.layers(
|
||||
decoder_state,
|
||||
encoder_state,
|
||||
keys_state,
|
||||
values_state,
|
||||
attention_state,
|
||||
attention_mask,
|
||||
token_index
|
||||
)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
decoder_state = self.lm_head(decoder_state)
|
||||
return decoder_state, keys_state, values_state
|
||||
return decoder_state, attention_state
|
||||
|
||||
def sample_image_tokens(
|
||||
self,
|
||||
|
@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
def sample_next_image_token(
|
||||
state: SampleState,
|
||||
token_index: int
|
||||
) -> Tuple[SampleState, None]:
|
||||
logits, keys_state, values_state = self.apply(
|
||||
) -> Tuple[SampleState, jnp.ndarray]:
|
||||
logits, attention_state = self.apply(
|
||||
{ 'params': params },
|
||||
encoder_state = encoder_state,
|
||||
keys_state = state.keys_state,
|
||||
values_state = state.values_state,
|
||||
attention_state = state.attention_state,
|
||||
attention_mask = attention_mask,
|
||||
prev_token = state.prev_token,
|
||||
token_index = token_index
|
||||
|
@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
state = SampleState(
|
||||
prev_token = next_token,
|
||||
prng_key = prng_key_next,
|
||||
keys_state = keys_state,
|
||||
values_state = values_state
|
||||
attention_state = attention_state
|
||||
)
|
||||
|
||||
return state, next_token
|
||||
|
||||
batch_count = encoder_state.shape[0]
|
||||
state_shape = (
|
||||
attention_state_shape = (
|
||||
self.layer_count,
|
||||
batch_count,
|
||||
batch_count * 2,
|
||||
self.image_token_count,
|
||||
self.attention_head_count,
|
||||
self.embed_count // self.attention_head_count
|
||||
self.embed_count
|
||||
)
|
||||
|
||||
initial_state = SampleState(
|
||||
prev_token = self.start_token,
|
||||
prng_key = prng_key,
|
||||
keys_state = jnp.zeros(state_shape),
|
||||
values_state = jnp.zeros(state_shape)
|
||||
attention_state = jnp.zeros(attention_state_shape)
|
||||
)
|
||||
|
||||
_, image_tokens = lax.scan(
|
||||
|
|
|
@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
|||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
keys_values: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
token_mask: BoolTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor]:
|
||||
batch_count = decoder_state.shape[0]
|
||||
keys = self.k_proj.forward(decoder_state)
|
||||
values = self.v_proj.forward(decoder_state)
|
||||
queries = self.q_proj.forward(decoder_state)
|
||||
keys_values = torch.where(
|
||||
attention_state = torch.where(
|
||||
token_mask[None, :, None],
|
||||
torch.cat([keys, values]),
|
||||
keys_values
|
||||
attention_state
|
||||
)
|
||||
keys, values = keys_values[:batch_count], keys_values[batch_count:]
|
||||
batch_count = decoder_state.shape[0]
|
||||
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
||||
decoder_state = super().forward(keys, values, queries, attention_mask)
|
||||
return decoder_state, keys_values
|
||||
return decoder_state, attention_state
|
||||
|
||||
|
||||
class DecoderLayerTorch(nn.Module):
|
||||
|
@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
|
|||
self,
|
||||
decoder_state: FloatTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
token_index: LongTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor]:
|
||||
|
@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
|
|||
self_attn_mask = self.token_indices < token_index + 1
|
||||
token_mask = self.token_indices == token_index
|
||||
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
||||
decoder_state, keys_values_state = self.self_attn.forward(
|
||||
decoder_state, attention_state = self.self_attn.forward(
|
||||
decoder_state,
|
||||
keys_values_state,
|
||||
attention_state,
|
||||
self_attn_mask,
|
||||
token_mask
|
||||
)
|
||||
|
@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
|
|||
decoder_state = self.glu.forward(decoder_state)
|
||||
decoder_state = residual + decoder_state
|
||||
|
||||
return decoder_state, keys_values_state
|
||||
return decoder_state, attention_state
|
||||
|
||||
|
||||
class DalleBartDecoderTorch(nn.Module):
|
||||
|
@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
||||
self.keys_values_state_shape = (
|
||||
layer_count * 2 * batch_count,
|
||||
self.attention_state_shape = (
|
||||
layer_count,
|
||||
2 * batch_count,
|
||||
image_token_count,
|
||||
embed_count
|
||||
)
|
||||
|
@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
self,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
prev_token_and_index: LongTensor
|
||||
) -> Tuple[LongTensor, FloatTensor]:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
|
@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
decoder_state += self.embed_positions.forward(token_index)
|
||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||
decoder_state = decoder_state[:, None]
|
||||
keys_values = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
|
||||
decoder_state, keys_values_layer = layer.forward(
|
||||
attention_states_new = []
|
||||
for i in range(self.layer_count):
|
||||
decoder_state, attention_state_layer = self.layers[i].forward(
|
||||
decoder_state,
|
||||
encoder_state,
|
||||
keys_values_state[j1:j2],
|
||||
attention_state[i],
|
||||
attention_mask,
|
||||
token_index[:1]
|
||||
)
|
||||
keys_values.append(keys_values_layer)
|
||||
attention_states_new.append(attention_state_layer)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
logits = self.lm_head(decoder_state)
|
||||
a = self.condition_factor
|
||||
|
@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
self.zero_prob,
|
||||
torch.exp(logits - top_logits[0])
|
||||
)
|
||||
return probs, torch.cat(keys_values)
|
||||
return probs, torch.stack(attention_states_new)
|
||||
|
||||
|
||||
def forward(
|
||||
|
@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
encoder_state: FloatTensor
|
||||
) -> LongTensor:
|
||||
image_tokens: List[LongTensor] = []
|
||||
keys_values_state = torch.zeros(self.keys_values_state_shape)
|
||||
attention_state = torch.zeros(self.attention_state_shape)
|
||||
if torch.cuda.is_available():
|
||||
keys_values_state = keys_values_state.cuda()
|
||||
attention_state = attention_state.cuda()
|
||||
image_token = self.start_token
|
||||
|
||||
for i in range(self.sample_token_count):
|
||||
token_index = self.token_indices[i:i+1]
|
||||
probs, keys_values_state = self.decode_step(
|
||||
probs, attention_state = self.decode_step(
|
||||
text_tokens = text_tokens,
|
||||
encoder_state = encoder_state,
|
||||
keys_values_state = keys_values_state,
|
||||
attention_state = attention_state,
|
||||
prev_token_and_index = torch.cat([image_token, token_index])
|
||||
)
|
||||
|
||||
|
|
|
@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
|
|||
queries: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
attention_bias: jnp.ndarray = lax.select(
|
||||
attention_mask,
|
||||
jnp.full(attention_mask.shape, 0.0),
|
||||
|
@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
|||
encoder_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
||||
keys = self.k_proj(encoder_state).reshape(shape_split)
|
||||
values = self.v_proj(encoder_state).reshape(shape_split)
|
||||
queries = self.q_proj(encoder_state).reshape(shape_split)
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
keys = self.k_proj(encoder_state)
|
||||
values = self.v_proj(encoder_state)
|
||||
queries = self.q_proj(encoder_state)
|
||||
return self.forward(keys, values, queries, attention_mask)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user