from functools import partial
import jax
from jax import lax, numpy as jnp
from flax import linen as nn


class GLUFlax(nn.Module):
    count_in_out: int
    count_middle: int

    def setup(self):
        self.gelu = partial(nn.gelu, approximate=False)
        self.ln0 = nn.LayerNorm(use_scale=False)
        self.ln1 = nn.LayerNorm(use_scale=False)
        self.fc0 = nn.Dense(self.count_middle, use_bias=False)
        self.fc1 = nn.Dense(self.count_middle, use_bias=False)
        self.fc2 = nn.Dense(self.count_in_out, use_bias=False)

    @nn.compact
    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
        z = self.ln0(z)
        z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z))
        z = self.fc2(z)
        return z


class AttentionFlax(nn.Module):
    head_count: int
    embed_count: int

    def setup(self):
        self.q_proj = nn.Dense(self.embed_count, use_bias=False)
        self.k_proj = nn.Dense(self.embed_count, use_bias=False)
        self.v_proj = nn.Dense(self.embed_count, use_bias=False)
        self.out_proj = nn.Dense(self.embed_count, use_bias=False)

    def forward(self,
        keys: jnp.ndarray,
        values: jnp.ndarray,
        queries: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        # Turn the boolean mask into an additive bias: 0.0 where a key may be
        # attended to, -inf where it is padding.
        attention_bias: jnp.ndarray = lax.select(
            attention_mask,
            jnp.full(attention_mask.shape, 0.0),
            jnp.full(attention_mask.shape, -jnp.inf),
        )
        # queries/keys/values: [batch, tokens, heads, head_dim]
        # attention_weights: [batch, heads, query_tokens, key_tokens]
        attention_weights: jnp.ndarray = jnp.einsum(
            'bqhd,bkhd->bhqk',
            queries,
            keys
        )
        attention_weights += attention_bias[:, None, None, :]
        attention_weights = jax.nn.softmax(attention_weights)
        attention_output: jnp.ndarray = jnp.einsum(
            "bhqk,bkhd->bqhd",
            attention_weights,
            values
        )
        # Merge the head and head_dim axes back into a single embedding axis.
        shape = attention_output.shape[:2] + (self.embed_count,)
        attention_output = attention_output.reshape(shape)
        attention_output = self.out_proj(attention_output)
        return attention_output
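

# Illustration only (not part of the model): a minimal sketch of the additive
# mask trick used in AttentionFlax.forward. The boolean mask becomes a bias of
# 0.0 where a key may be attended to and -inf where it is padding, so softmax
# assigns those positions zero weight. The toy mask below is an assumption
# chosen just for this demo.
def _attention_bias_demo() -> jnp.ndarray:
    mask = jnp.array([[True, True, False]])
    bias = lax.select(
        mask,
        jnp.full(mask.shape, 0.0),
        jnp.full(mask.shape, -jnp.inf)
    )
    # Uniform scores plus the bias: the masked position gets zero probability.
    return jax.nn.softmax(jnp.zeros(mask.shape) + bias)  # [[0.5, 0.5, 0.0]]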


class EncoderSelfAttentionFlax(AttentionFlax):
    def __call__(
        self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        # Split the embedding axis into (head_count, head_dim) per projection.
        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
        keys = self.k_proj(encoder_state).reshape(shape_split)
        values = self.v_proj(encoder_state).reshape(shape_split)
        queries = self.q_proj(encoder_state).reshape(shape_split)
        # Scale queries by 1/sqrt(head_dim) before the dot product.
        queries /= queries.shape[-1] ** 0.5
        return self.forward(keys, values, queries, attention_mask)


class DalleBartEncoderLayerFlax(nn.Module):
    attention_head_count: int
    embed_count: int
    glu_embed_count: int

    def setup(self):
        self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
        self.self_attn = EncoderSelfAttentionFlax(
            self.attention_head_count,
            self.embed_count
        )
        self.self_attn_layer_norm = nn.LayerNorm()
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        # Self-attention block with pre-norm and a residual connection.
        residual = encoder_state
        encoder_state = self.pre_self_attn_layer_norm(encoder_state)
        encoder_state = self.self_attn(encoder_state, attention_mask)
        encoder_state = self.self_attn_layer_norm(encoder_state)
        encoder_state = residual + encoder_state
        # Feed-forward (GLU) block with a residual connection.
        residual = encoder_state
        encoder_state = self.glu(encoder_state)
        encoder_state = residual + encoder_state
        # Return a (carry, output) pair as expected by nn.scan below.
        return encoder_state, None


class DalleBartEncoderFlax(nn.Module):
    attention_head_count: int
    embed_count: int
    glu_embed_count: int
    text_token_count: int
    text_vocab_count: int
    layer_count: int

    def setup(self):
        self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count)
        self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
        # Stack layer_count encoder layers as a single scanned module, so the
        # per-layer parameters share one leading axis.
        self.layers = nn.scan(
            DalleBartEncoderLayerFlax,
            variable_axes = { "params": 0, "cache": 0 },
            split_rngs = { "params": True },
            in_axes = nn.broadcast,
            length = self.layer_count
        )(
            self.attention_head_count,
            self.embed_count,
            self.glu_embed_count,
            name="FlaxBartEncoderLayers"
        )
        self.layernorm_embedding = nn.LayerNorm()
        self.final_ln = nn.LayerNorm(use_scale=False)

    def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray:
        batch_count, token_count = text_tokens.shape
        pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1))
        # Token id 1 is treated as padding and masked out of attention.
        attention_mask = jnp.not_equal(text_tokens, 1)
        encoder_state = (
            self.embed_tokens(text_tokens) +
            self.embed_positions(pose_tokens)
        )
        encoder_state = self.layernorm_embedding(encoder_state)
        encoder_state, _ = self.layers(encoder_state, attention_mask)
        encoder_state = self.final_ln(encoder_state)
        return encoder_state
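

# Minimal usage sketch (illustration only): the hyperparameters below are made
# up for the example and are far smaller than any real DALL-E BART checkpoint;
# in practice the parameters would be loaded from a converted checkpoint rather
# than randomly initialized.
if __name__ == "__main__":
    encoder = DalleBartEncoderFlax(
        attention_head_count=2,
        embed_count=32,
        glu_embed_count=64,
        text_token_count=16,
        text_vocab_count=100,
        layer_count=2
    )
    text_tokens = jnp.zeros((1, 16), dtype=jnp.int32)  # dummy batch of token ids
    params = encoder.init(jax.random.PRNGKey(0), text_tokens)
    encoder_state = encoder.apply(params, text_tokens)
    print(encoder_state.shape)  # expected: (1, 16, 32)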