min-dalle-test/min_dalle/models/dalle_bart_encoder_flax.py

143 lines
5.0 KiB
Python
Raw Normal View History

2022-06-27 15:57:56 +00:00
from functools import partial
from typing import Tuple

import jax
from flax import linen as nn
from jax import lax, numpy as jnp
class GLUFlax(nn.Module):
    """Gated feed-forward block made of three bias-free dense layers.

    The input is layer-normalised, sent through two parallel projections
    (``fc0`` and ``fc1``) whose outputs are combined as
    ``gelu(fc0(x)) * fc1(x)``, layer-normalised again and finally projected
    by ``fc2``.
    """

    # Output feature size of the final projection ``fc2``.
    count_in_out: int
    # Output feature size of the two parallel projections ``fc0`` / ``fc1``.
    count_middle: int

    def setup(self):
        """Create the activation, the two layer norms and the projections."""
        # Exact (erf-based) GELU rather than the tanh approximation.
        self.gelu = partial(nn.gelu, approximate=False)
        self.ln0 = nn.LayerNorm(use_scale=False)
        self.ln1 = nn.LayerNorm(use_scale=False)
        self.fc0 = nn.Dense(self.count_middle, use_bias=False)
        self.fc1 = nn.Dense(self.count_middle, use_bias=False)
        self.fc2 = nn.Dense(self.count_in_out, use_bias=False)

    @nn.compact
    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
        """Apply the gated feed-forward transform to ``z``.

        Args:
            z: Input array; the dense layers act on its last axis.

        Returns:
            The transformed array, with ``count_in_out`` features on its
            last axis.
        """
        normed = self.ln0(z)
        gate = self.gelu(self.fc0(normed))
        value = self.fc1(normed)
        gated = self.ln1(gate * value)
        return self.fc2(gated)
class AttentionFlax(nn.Module):
    """Base multi-head attention: owns the projections and the core math.

    Subclasses are expected to project and reshape their inputs and then
    call :meth:`forward`.
    """

    # Number of attention heads; read by subclasses when splitting heads.
    head_count: int
    # Output feature size of every projection in this module.
    embed_count: int

    def setup(self):
        """Create the bias-free query/key/value/output projections."""
        self.q_proj = nn.Dense(self.embed_count, use_bias=False)
        self.k_proj = nn.Dense(self.embed_count, use_bias=False)
        self.v_proj = nn.Dense(self.embed_count, use_bias=False)
        self.out_proj = nn.Dense(self.embed_count, use_bias=False)

    def forward(
        self,
        keys: jnp.ndarray,
        values: jnp.ndarray,
        queries: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute masked attention and apply the output projection.

        Axis labels below are the ones used by the einsum expressions.

        Args:
            keys: Rank-4 array indexed as ``b, k, h, d``.
            values: Rank-4 array indexed as ``b, k, h, d``.
            queries: Rank-4 array indexed as ``b, q, h, d``. No scaling is
                applied here, so callers must pre-scale if they need it.
            attention_mask: Rank-2 boolean array indexed as ``b, k``; keys
                whose entry is False receive a bias of ``-inf``.

        Returns:
            Array indexed as ``b, q, embed_count`` after ``out_proj``.
        """
        # Additive bias: 0 where the mask is True, -inf where it is False.
        mask_shape = attention_mask.shape
        keep_bias = jnp.full(mask_shape, 0.0)
        drop_bias = jnp.full(mask_shape, -jnp.inf)
        attention_bias: jnp.ndarray = lax.select(
            attention_mask, keep_bias, drop_bias
        )

        scores: jnp.ndarray = jnp.einsum('bqhd,bkhd->bhqk', queries, keys)
        # Insert head and query axes so the per-key bias broadcasts.
        scores = scores + attention_bias[:, None, None, :]
        probabilities = jax.nn.softmax(scores, axis=-1)

        context: jnp.ndarray = jnp.einsum(
            "bhqk,bkhd->bqhd", probabilities, values
        )
        # Merge the head and per-head axes back into one feature axis.
        merged_shape = context.shape[:2] + (self.embed_count,)
        merged = context.reshape(merged_shape)
        return self.out_proj(merged)
class EncoderSelfAttentionFlax(AttentionFlax):
    """Self-attention: keys, values and queries come from the same state."""

    def __call__(
        self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        """Project ``encoder_state`` into heads and attend over itself.

        Args:
            encoder_state: Array whose first two axes are kept as-is and
                whose projected features are split across ``head_count``
                heads.
            attention_mask: Mask passed through unchanged to ``forward``.

        Returns:
            The output of ``forward`` for the projected inputs.
        """
        leading_axes = encoder_state.shape[:2]
        # -1 lets reshape infer the per-head feature size.
        per_head_shape = leading_axes + (self.head_count, -1)

        keys = self.k_proj(encoder_state).reshape(per_head_shape)
        values = self.v_proj(encoder_state).reshape(per_head_shape)
        queries = self.q_proj(encoder_state).reshape(per_head_shape)

        # Scale queries by 1/sqrt(per-head size) before the dot product.
        head_width = queries.shape[-1]
        queries = queries / head_width ** 0.5
        return self.forward(keys, values, queries, attention_mask)
class DalleBartEncoderLayerFlax(nn.Module):
    """One encoder layer: self-attention and a GLU block, each with a residual.

    ``__call__`` returns a ``(state, None)`` pair rather than a bare array,
    so the layer can be used as the body of ``nn.scan`` (carry plus
    per-step output).
    """

    # Number of heads given to the self-attention sub-module.
    attention_head_count: int
    # Feature size given to the self-attention and GLU sub-modules.
    embed_count: int
    # Middle feature size given to the GLU sub-module.
    glu_embed_count: int

    def setup(self):
        """Create the layer norms, the self-attention and the GLU block."""
        self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
        self.self_attn = EncoderSelfAttentionFlax(
            self.attention_head_count,
            self.embed_count
        )
        self.self_attn_layer_norm = nn.LayerNorm()
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(
        self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> Tuple[jnp.ndarray, None]:
        """Run the layer on ``encoder_state``.

        Args:
            encoder_state: Input state; also used as the first residual.
            attention_mask: Mask forwarded to the self-attention sub-module.

        Returns:
            A ``(encoder_state, None)`` tuple: the updated state and a
            ``None`` placeholder for the per-step scan output.
        """
        # Self-attention sub-block: norm -> attention -> norm -> residual.
        residual = encoder_state
        encoder_state = self.pre_self_attn_layer_norm(encoder_state)
        encoder_state = self.self_attn(encoder_state, attention_mask)
        encoder_state = self.self_attn_layer_norm(encoder_state)
        encoder_state = residual + encoder_state

        # Feed-forward sub-block: GLU (which normalises internally) + residual.
        residual = encoder_state
        encoder_state = self.glu(encoder_state)
        encoder_state = residual + encoder_state
        return encoder_state, None
class DalleBartEncoderFlax(nn.Module):
    """Text encoder: token + position embeddings followed by scanned layers."""

    # Head count forwarded to every encoder layer.
    attention_head_count: int
    # Embedding size, also forwarded to every encoder layer.
    embed_count: int
    # GLU middle size forwarded to every encoder layer.
    glu_embed_count: int
    # Number of rows in the position-embedding table.
    text_token_count: int
    # Number of rows in the token-embedding table.
    text_vocab_count: int
    # Number of encoder layers (the scan length).
    layer_count: int

    def setup(self):
        """Create the embedding tables, the scanned layer stack and norms."""
        self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count)
        self.embed_positions = nn.Embed(
            self.text_token_count, self.embed_count
        )

        # Scan one layer definition `layer_count` times: "params"/"cache"
        # variables are stacked on axis 0, each step gets its own "params"
        # rng, and the non-carry argument is broadcast to every step.
        scanned_layer_cls = nn.scan(
            DalleBartEncoderLayerFlax,
            variable_axes={"params": 0, "cache": 0},
            split_rngs={"params": True},
            in_axes=nn.broadcast,
            length=self.layer_count
        )
        self.layers = scanned_layer_cls(
            self.attention_head_count,
            self.embed_count,
            self.glu_embed_count,
            name="FlaxBartEncoderLayers"
        )

        self.layernorm_embedding = nn.LayerNorm()
        self.final_ln = nn.LayerNorm(use_scale=False)

    def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray:
        """Encode a batch of token ids.

        Args:
            text_tokens: Rank-2 integer array unpacked as
                ``(batch_count, token_count)``.

        Returns:
            The encoder state after all layers and the final layer norm.
        """
        batch_count, token_count = text_tokens.shape

        # Position ids 0..token_count-1, repeated for every batch row.
        positions = jnp.arange(token_count)
        pose_tokens = jnp.tile(positions, (batch_count, 1))

        # Token id 1 is masked out of attention (presumably the padding
        # id; confirm against the tokenizer).
        attention_mask = jnp.not_equal(text_tokens, 1)

        token_embedding = self.embed_tokens(text_tokens)
        position_embedding = self.embed_positions(pose_tokens)
        encoder_state = token_embedding + position_embedding
        encoder_state = self.layernorm_embedding(encoder_state)

        # The scanned stack returns (carry, per-step outputs); only the
        # carry is needed.
        encoder_state, _ = self.layers(encoder_state, attention_mask)
        return self.final_ln(encoder_state)