simplified flax attention and matched torch attention

This commit is contained in:
Brett Kuprel
2022-06-29 14:56:28 -04:00
parent 61cc99c13c
commit d99828a239
4 changed files with 72 additions and 87 deletions

View File

@@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray,
) -> jnp.ndarray:
keys: jnp.ndarray = self.k_proj(encoder_state)
values: jnp.ndarray = self.v_proj(encoder_state)
queries: jnp.ndarray = self.q_proj(decoder_state)
query_shape = queries.shape[:2] + (self.head_count, -1)
key_value_shape = keys.shape[:2] + (self.head_count, -1)
keys = keys.reshape(key_value_shape)
values = values.reshape(key_value_shape)
queries = queries.reshape(query_shape)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(decoder_state)
return self.forward(keys, values, queries, attention_mask)
@@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(
self,
decoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
state_index: tuple
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
keys_state = lax.dynamic_update_slice(
keys_state,
self.k_proj(decoder_state).reshape(shape_split),
) -> Tuple[jnp.ndarray, jnp.ndarray]:
keys = self.k_proj(decoder_state)
values = self.v_proj(decoder_state)
queries = self.q_proj(decoder_state)
attention_state = lax.dynamic_update_slice(
attention_state,
jnp.concatenate([keys, values]),
state_index
)
values_state = lax.dynamic_update_slice(
values_state,
self.v_proj(decoder_state).reshape(shape_split),
state_index
)
queries = self.q_proj(decoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
batch_count = decoder_state.shape[0]
keys, values = attention_state[:batch_count], attention_state[batch_count:]
decoder_state = self.forward(
keys_state,
values_state,
keys,
values,
queries,
attention_mask
)
return decoder_state, (keys_state, values_state)
return decoder_state, attention_state
class DalleBartDecoderLayerFlax(nn.Module):
@@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
token_index: int
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
) -> Tuple[jnp.ndarray, jnp.ndarray]:
# Self Attention
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
jnp.arange(self.image_token_count) < token_index + 1,
(decoder_state.shape[0], 1)
)
decoder_state, keys_values_state = self.self_attn(
decoder_state, attention_state = self.self_attn(
decoder_state,
keys_state,
values_state,
attention_state,
self_attention_mask,
(0, token_index, 0, 0)
(0, token_index, 0)
)
decoder_state = self.self_attn_layer_norm(decoder_state)
decoder_state = residual + decoder_state
@@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
decoder_state = self.glu(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, keys_values_state
return decoder_state, attention_state
@flax.struct.dataclass
class SampleState:
prev_token: jnp.ndarray
prng_key: jnp.ndarray
keys_state: jnp.ndarray
values_state: jnp.ndarray
attention_state: jnp.ndarray
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
return a * logits[0, -1] + (1 - a) * logits[1, -1]
@@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
DalleBartDecoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
split_rngs = { "params": True },
in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
out_axes = (0, 0),
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
out_axes = 0,
length=self.layer_count,
)(
self.image_token_count,
@@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
def __call__(
self,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
prev_token: int,
token_index: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray]:
batch_count = encoder_state.shape[0]
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
decoder_state = self.embed_tokens(prev_token * ones)
decoder_state += self.embed_positions(token_index * ones)
decoder_state = self.layernorm_embedding(decoder_state)
decoder_state, (keys_state, values_state) = self.layers(
decoder_state, attention_state = self.layers(
decoder_state,
encoder_state,
keys_state,
values_state,
attention_state,
attention_mask,
token_index
)
decoder_state = self.final_ln(decoder_state)
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
return decoder_state, attention_state
def sample_image_tokens(
self,
@@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
def sample_next_image_token(
state: SampleState,
token_index: int
) -> Tuple[SampleState, None]:
logits, keys_state, values_state = self.apply(
) -> Tuple[SampleState, jnp.ndarray]:
logits, attention_state = self.apply(
{ 'params': params },
encoder_state = encoder_state,
keys_state = state.keys_state,
values_state = state.values_state,
attention_state = state.attention_state,
attention_mask = attention_mask,
prev_token = state.prev_token,
token_index = token_index
@@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
state = SampleState(
prev_token = next_token,
prng_key = prng_key_next,
keys_state = keys_state,
values_state = values_state
attention_state = attention_state
)
return state, next_token
batch_count = encoder_state.shape[0]
state_shape = (
attention_state_shape = (
self.layer_count,
batch_count,
batch_count * 2,
self.image_token_count,
self.attention_head_count,
self.embed_count // self.attention_head_count
self.embed_count
)
initial_state = SampleState(
prev_token = self.start_token,
prng_key = prng_key,
keys_state = jnp.zeros(state_shape),
values_state = jnp.zeros(state_shape)
attention_state = jnp.zeros(attention_state_shape)
)
_, image_tokens = lax.scan(

View File

@@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
def forward(
self,
decoder_state: FloatTensor,
keys_values: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_mask: BoolTensor
) -> Tuple[FloatTensor, FloatTensor]:
batch_count = decoder_state.shape[0]
keys = self.k_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state)
queries = self.q_proj.forward(decoder_state)
keys_values = torch.where(
attention_state = torch.where(
token_mask[None, :, None],
torch.cat([keys, values]),
keys_values
attention_state
)
keys, values = keys_values[:batch_count], keys_values[batch_count:]
batch_count = decoder_state.shape[0]
keys, values = attention_state[:batch_count], attention_state[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask)
return decoder_state, keys_values
return decoder_state, attention_state
class DecoderLayerTorch(nn.Module):
@@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
@@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
self_attn_mask = self.token_indices < token_index + 1
token_mask = self.token_indices == token_index
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
decoder_state, keys_values_state = self.self_attn.forward(
decoder_state, attention_state = self.self_attn.forward(
decoder_state,
keys_values_state,
attention_state,
self_attn_mask,
token_mask
)
@@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
decoder_state = self.glu.forward(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, keys_values_state
return decoder_state, attention_state
class DalleBartDecoderTorch(nn.Module):
@@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
self.keys_values_state_shape = (
layer_count * 2 * batch_count,
self.attention_state_shape = (
layer_count,
2 * batch_count,
image_token_count,
embed_count
)
@@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
self,
text_tokens: LongTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
attention_state: FloatTensor,
prev_token_and_index: LongTensor
) -> Tuple[LongTensor, FloatTensor]:
attention_mask = text_tokens.not_equal(1)
@@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
decoder_state += self.embed_positions.forward(token_index)
decoder_state = self.layernorm_embedding.forward(decoder_state)
decoder_state = decoder_state[:, None]
keys_values = []
for i, layer in enumerate(self.layers):
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
decoder_state, keys_values_layer = layer.forward(
attention_states_new = []
for i in range(self.layer_count):
decoder_state, attention_state_layer = self.layers[i].forward(
decoder_state,
encoder_state,
keys_values_state[j1:j2],
attention_state[i],
attention_mask,
token_index[:1]
)
keys_values.append(keys_values_layer)
attention_states_new.append(attention_state_layer)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = self.condition_factor
@@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[0])
)
return probs, torch.cat(keys_values)
return probs, torch.stack(attention_states_new)
def forward(
@@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
encoder_state: FloatTensor
) -> LongTensor:
image_tokens: List[LongTensor] = []
keys_values_state = torch.zeros(self.keys_values_state_shape)
attention_state = torch.zeros(self.attention_state_shape)
if torch.cuda.is_available():
keys_values_state = keys_values_state.cuda()
attention_state = attention_state.cuda()
image_token = self.start_token
for i in range(self.sample_token_count):
token_index = self.token_indices[i:i+1]
probs, keys_values_state = self.decode_step(
probs, attention_state = self.decode_step(
text_tokens = text_tokens,
encoder_state = encoder_state,
keys_values_state = keys_values_state,
attention_state = attention_state,
prev_token_and_index = torch.cat([image_token, token_index])
)

View File

@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
queries: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias: jnp.ndarray = lax.select(
attention_mask,
jnp.full(attention_mask.shape, 0.0),
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
keys = self.k_proj(encoder_state).reshape(shape_split)
values = self.v_proj(encoder_state).reshape(shape_split)
queries = self.q_proj(encoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(encoder_state)
return self.forward(keys, values, queries, attention_mask)