fixed relative imports

min_dalle/models/__init__.py (new file, 0 lines)

min_dalle/models/dalle_bart_decoder_flax.py (new file, 288 lines)
@@ -0,0 +1,288 @@
import jax, flax
from jax import lax, numpy as jnp
from flax import linen as nn
from typing import Tuple

from .dalle_bart_encoder_flax import GLUFlax, AttentionFlax


class DecoderCrossAttentionFlax(AttentionFlax):
    def __call__(
        self,
        decoder_state: jnp.ndarray,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray,
    ) -> jnp.ndarray:
        keys: jnp.ndarray = self.k_proj(encoder_state)
        values: jnp.ndarray = self.v_proj(encoder_state)
        queries: jnp.ndarray = self.q_proj(decoder_state)
        query_shape = queries.shape[:2] + (self.head_count, -1)
        key_value_shape = keys.shape[:2] + (self.head_count, -1)
        keys = keys.reshape(key_value_shape)
        values = values.reshape(key_value_shape)
        queries = queries.reshape(query_shape)
        queries /= queries.shape[-1] ** 0.5
        return self.forward(keys, values, queries, attention_mask)


class DecoderSelfAttentionFlax(AttentionFlax):
    def __call__(self,
        decoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,
        values_state: jnp.ndarray,
        attention_mask: jnp.ndarray,
        state_index: tuple
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
        # write this step's keys and values into the cache at state_index
        keys_state = lax.dynamic_update_slice(
            keys_state,
            self.k_proj(decoder_state).reshape(shape_split),
            state_index
        )
        values_state = lax.dynamic_update_slice(
            values_state,
            self.v_proj(decoder_state).reshape(shape_split),
            state_index
        )
        queries = self.q_proj(decoder_state).reshape(shape_split)
        queries /= queries.shape[-1] ** 0.5
        decoder_state = self.forward(
            keys_state,
            values_state,
            queries,
            attention_mask
        )
        return decoder_state, (keys_state, values_state)


class DalleBartDecoderLayerFlax(nn.Module):
    image_token_count: int
    attention_head_count: int
    embed_count: int
    glu_embed_count: int

    def setup(self):
        self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
        self.self_attn = DecoderSelfAttentionFlax(
            self.attention_head_count,
            self.embed_count
        )
        self.self_attn_layer_norm = nn.LayerNorm()
        self.pre_encoder_attn_layer_norm = nn.LayerNorm(use_scale=False)
        self.encoder_attn = DecoderCrossAttentionFlax(
            self.attention_head_count,
            self.embed_count,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm()
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(self,
        decoder_state: jnp.ndarray,
        encoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,
        values_state: jnp.ndarray,
        attention_mask: jnp.ndarray,
        token_index: int
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        # Self Attention
        residual = decoder_state
        decoder_state = self.pre_self_attn_layer_norm(decoder_state)
        # causal mask: positions up to and including token_index are visible
        self_attention_mask = jnp.tile(
            jnp.arange(self.image_token_count) < token_index + 1,
            (decoder_state.shape[0], 1)
        )
        decoder_state, keys_values_state = self.self_attn(
            decoder_state,
            keys_state,
            values_state,
            self_attention_mask,
            (0, token_index, 0, 0)
        )
        decoder_state = self.self_attn_layer_norm(decoder_state)
        decoder_state = residual + decoder_state

        # Cross Attention
        residual = decoder_state
        decoder_state = self.pre_encoder_attn_layer_norm(decoder_state)
        decoder_state = self.encoder_attn(
            decoder_state,
            encoder_state,
            attention_mask
        )
        decoder_state = self.encoder_attn_layer_norm(decoder_state)
        decoder_state = residual + decoder_state

        # Feed forward
        residual = decoder_state
        decoder_state = self.glu(decoder_state)
        decoder_state = residual + decoder_state

        return decoder_state, keys_values_state


@flax.struct.dataclass
class SampleState:
    prev_token: jnp.ndarray
    prng_key: jnp.ndarray
    keys_state: jnp.ndarray
    values_state: jnp.ndarray


def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
    # super conditioning: weighted mix of the logits from the two batch rows
    return a * logits[0, -1] + (1 - a) * logits[1, -1]


def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
    # keep the k largest logits and suppress the rest to -inf
    top_logits, top_tokens = lax.top_k(logits, k)
    suppressed = -jnp.inf * jnp.ones_like(logits)
    return lax.select(logits < top_logits[-1], suppressed, logits)


class DalleBartDecoderFlax(nn.Module):
    image_token_count: int
    text_token_count: int
    image_vocab_count: int
    attention_head_count: int
    embed_count: int
    glu_embed_count: int
    layer_count: int
    start_token: int

    def setup(self):
        self.embed_tokens = nn.Embed(
            self.image_vocab_count + 1,
            self.embed_count
        )
        self.embed_positions = nn.Embed(
            self.image_token_count,
            self.embed_count
        )
        # one scanned layer whose parameters are stacked along axis 0
        self.layers = nn.scan(
            DalleBartDecoderLayerFlax,
            variable_axes = { "params": 0, "cache": 0 },
            split_rngs = { "params": True },
            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
            out_axes = (0, 0),
            length=self.layer_count,
        )(
            self.image_token_count,
            self.attention_head_count,
            self.embed_count,
            self.glu_embed_count,
            name="FlaxBartDecoderLayers"
        )
        self.layernorm_embedding = nn.LayerNorm()
        self.final_ln = nn.LayerNorm(use_scale=False)
        self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)

    def __call__(self,
        encoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,
        values_state: jnp.ndarray,
        attention_mask: jnp.ndarray,
        prev_token: int,
        token_index: int
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        batch_count = encoder_state.shape[0]
        ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
        decoder_state = self.embed_tokens(prev_token * ones)
        decoder_state += self.embed_positions(token_index * ones)
        decoder_state = self.layernorm_embedding(decoder_state)
        decoder_state, (keys_state, values_state) = self.layers(
            decoder_state,
            encoder_state,
            keys_state,
            values_state,
            attention_mask,
            token_index
        )
        decoder_state = self.final_ln(decoder_state)
        decoder_state = self.lm_head(decoder_state)
        return decoder_state, keys_state, values_state

    def compute_logits(self,
        text_tokens: jnp.ndarray,
        encoder_state: jnp.ndarray,
        params: dict
    ) -> jnp.ndarray:
        batch_count = encoder_state.shape[0]
        state_shape = (
            self.layer_count,
            batch_count,
            self.image_token_count,
            self.attention_head_count,
            self.embed_count // self.attention_head_count
        )
        keys_state = jnp.zeros(state_shape)
        values_state = jnp.zeros(state_shape)

        logits, _, _ = self.apply(
            { 'params': params },
            encoder_state = encoder_state,
            keys_state = keys_state,
            values_state = values_state,
            attention_mask = jnp.not_equal(text_tokens, 1),
            prev_token = self.start_token,
            token_index = 0
        )

        return super_conditioned(logits, 10.0)

    def sample_image_tokens(self,
        text_tokens: jnp.ndarray,
        encoder_state: jnp.ndarray,
        prng_key: jax.random.PRNGKey,
        params: dict
    ) -> jnp.ndarray:
        # token id 1 is the pad token
        attention_mask = jnp.not_equal(text_tokens, 1)

        def sample_next_image_token(
            state: SampleState,
            token_index: int
        ) -> Tuple[SampleState, jnp.ndarray]:
            logits, keys_state, values_state = self.apply(
                { 'params': params },
                encoder_state = encoder_state,
                keys_state = state.keys_state,
                values_state = state.values_state,
                attention_mask = attention_mask,
                prev_token = state.prev_token,
                token_index = token_index
            )

            logits = super_conditioned(logits, 10.0)
            logits = keep_top_k(logits, k=50)

            prng_key, prng_key_next = jax.random.split(state.prng_key)
            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            state = SampleState(
                prev_token = next_token,
                prng_key = prng_key_next,
                keys_state = keys_state,
                values_state = values_state
            )

            return state, next_token

        batch_count = encoder_state.shape[0]
        state_shape = (
            self.layer_count,
            batch_count,
            self.image_token_count,
            self.attention_head_count,
            self.embed_count // self.attention_head_count
        )

        initial_state = SampleState(
            prev_token = self.start_token,
            prng_key = prng_key,
            keys_state = jnp.zeros(state_shape),
            values_state = jnp.zeros(state_shape)
        )

        _, image_tokens = lax.scan(
            sample_next_image_token,
            initial_state,
            jnp.arange(self.image_token_count)
        )

        return image_tokens
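
The two helpers above, super_conditioned and keep_top_k, are the heart of the sampling loop. A minimal sketch of how they combine on dummy logits (the vocabulary size of 8 and the flat import path are made up for illustration; real logits come from lm_head):

import jax
from jax import numpy as jnp
from dalle_bart_decoder_flax import super_conditioned, keep_top_k  # hypothetical flat import

# two batch rows, one generated position, vocabulary of 8 (toy numbers)
logits = jnp.stack([
    jnp.linspace(0.0, 1.0, 8)[None, :],
    jnp.zeros((1, 8)),
])  # shape (2, 1, 8)

mixed = super_conditioned(logits, 10.0)  # shape (8,): 10 * row 0 - 9 * row 1
filtered = keep_top_k(mixed, k=3)        # everything outside the top 3 becomes -inf
token = jax.random.categorical(jax.random.PRNGKey(0), filtered)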

min_dalle/models/dalle_bart_decoder_torch.py (new file, 214 lines)
@@ -0,0 +1,214 @@
import torch
from torch import LongTensor, nn, FloatTensor, BoolTensor
from typing import List, Tuple

from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch


class DecoderCrossAttentionTorch(AttentionTorch):
    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        keys = self.k_proj.forward(encoder_state)
        values = self.v_proj.forward(encoder_state)
        queries = self.q_proj.forward(decoder_state)
        return super().forward(keys, values, queries, attention_mask)


class DecoderSelfAttentionTorch(AttentionTorch):
    def forward(self,
        decoder_state: FloatTensor,
        keys_values: FloatTensor,
        attention_mask: BoolTensor,
        token_index: LongTensor
    ) -> Tuple[FloatTensor, FloatTensor]:
        keys = self.k_proj.forward(decoder_state)
        values = self.v_proj.forward(decoder_state)
        queries = self.q_proj.forward(decoder_state)

        # write this step's keys and values into the cache at token_index;
        # keys occupy the first batch_count rows, values the rest
        batch_count = decoder_state.shape[0]
        token_count = keys_values.shape[-1]
        keys_values = torch.where(
            (torch.arange(token_count) == token_index)[None, None, :],
            torch.cat([keys, values]).squeeze(2),
            keys_values
        )
        keys = keys_values[:batch_count, :, None]
        values = keys_values[batch_count:, :, None]

        decoder_state = super().forward(keys, values, queries, attention_mask)
        return decoder_state, keys_values


class DecoderLayerTorch(nn.Module):
    def __init__(self,
        image_token_count: int,
        head_count: int,
        embed_count: int,
        glu_embed_count: int
    ):
        super().__init__()
        self.image_token_count = image_token_count
        self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
        self.self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
        self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
        self.glu = GLUTorch(embed_count, glu_embed_count)

    def forward(self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,
        attention_mask: BoolTensor,
        token_index: LongTensor
    ) -> Tuple[FloatTensor, FloatTensor]:
        # Self Attention
        residual = decoder_state
        decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
        self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
        self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
        # (batch, 1, embed) -> (batch, embed, 1, 1) for the 1x1 convolutions
        decoder_state = decoder_state.transpose(1, 2).unsqueeze(2)
        decoder_state, keys_values_state = self.self_attn.forward(
            decoder_state,
            keys_values_state,
            self_attn_mask,
            token_index
        )
        decoder_state = decoder_state.transpose(1, 3).squeeze(2)
        decoder_state = self.self_attn_layer_norm.forward(decoder_state)
        decoder_state = residual + decoder_state

        # Cross Attention
        residual = decoder_state
        decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
        decoder_state = decoder_state.transpose(1, 2).unsqueeze(2)
        decoder_state = self.encoder_attn.forward(
            decoder_state,
            encoder_state,
            attention_mask
        )
        decoder_state = decoder_state.transpose(1, 3).squeeze(2)
        decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
        decoder_state = residual + decoder_state

        # Feed forward
        residual = decoder_state
        decoder_state = self.glu.forward(decoder_state)
        decoder_state = residual + decoder_state

        return decoder_state, keys_values_state


class DalleBartDecoderTorch(nn.Module):
    def __init__(self,
        image_vocab_size: int,
        image_token_count: int,
        sample_token_count: int,
        embed_count: int,
        attention_head_count: int,
        glu_embed_count: int,
        layer_count: int,
        batch_count: int,
        start_token: int,
        is_verbose: bool
    ):
        super().__init__()
        self.is_verbose = is_verbose
        self.layer_count = layer_count
        self.sample_token_count = sample_token_count
        self.start_token = torch.tensor([start_token]).to(torch.long)
        self.pad_token = torch.tensor([1]).to(torch.long)
        self.condition_factor = torch.tensor([10]).to(torch.float)
        self.image_token_count = image_token_count
        self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
        self.embed_positions = nn.Embedding(image_token_count, embed_count)
        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
            DecoderLayerTorch(
                image_token_count,
                attention_head_count,
                embed_count,
                glu_embed_count
            )
            for _ in range(layer_count)
        ])
        self.layernorm_embedding = nn.LayerNorm(embed_count)
        self.final_ln = nn.LayerNorm(embed_count)
        self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
        # one packed cache tensor: 2 * batch_count rows per layer (keys, then values)
        self.keys_values_state_shape = (
            layer_count * 2 * batch_count,
            embed_count,
            image_token_count
        )

    def decode_step(self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,
        prev_token_and_index: LongTensor
    ) -> Tuple[FloatTensor, FloatTensor]:
        attention_mask = text_tokens.not_equal(self.pad_token)
        batch_count = encoder_state.shape[0]
        prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
        token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
        decoder_state = self.embed_tokens.forward(prev_token)
        decoder_state += self.embed_positions.forward(token_index)
        decoder_state = self.layernorm_embedding.forward(decoder_state)
        decoder_state = decoder_state[:, None]  # (batch_count, 1, embed_count)
        keys_values = []
        for i, layer in enumerate(self.layers):
            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
            decoder_state, keys_values_layer = layer.forward(
                decoder_state,
                encoder_state,
                keys_values_state[j1:j2],
                attention_mask,
                token_index[:1]
            )
            keys_values.append(keys_values_layer)
        keys_values = torch.cat(keys_values, dim=0)
        decoder_state = self.final_ln(decoder_state)  # (batch_count, 1, embed_count)
        logits = self.lm_head(decoder_state)  # (batch_count, 1, vocab_size)
        # super conditioning: weighted mix of the two batch rows' logits
        a = self.condition_factor
        logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]

        # top-k filtering: zero the probability of everything outside the 50 largest logits
        top_logits = logits.sort(descending=True)[0][:50]
        probs = torch.where(
            logits < top_logits[-1],
            torch.zeros([1]),
            torch.exp(logits - top_logits[0])
        )
        return probs, keys_values

    def forward(self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor
    ) -> LongTensor:
        image_tokens: List[LongTensor] = []
        keys_values_state = torch.zeros(self.keys_values_state_shape)
        image_token = self.start_token
        encoder_state = encoder_state.transpose(1, 2).unsqueeze(2)

        for i in range(self.sample_token_count):
            token_index = torch.tensor([i]).to(torch.long)
            probs, keys_values_state = self.decode_step(
                text_tokens = text_tokens,
                encoder_state = encoder_state,
                keys_values_state = keys_values_state,
                prev_token_and_index = torch.cat([image_token, token_index])
            )

            image_token = torch.multinomial(probs, 1)
            image_tokens += [image_token]

            if self.is_verbose:
                token = int(image_token.detach().numpy())
                print("image token {} is {}".format(i, token))

        return torch.cat(image_tokens)
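
A minimal usage sketch for the torch decoder with toy hyperparameters (all sizes below are made up; the real checkpoint uses image_vocab_size 16384, embed_count 1024, and so on). The batch must hold exactly two rows, since decode_step mixes logits[0] and logits[1] for super conditioning:

import torch
from dalle_bart_decoder_torch import DalleBartDecoderTorch  # hypothetical flat import

decoder = DalleBartDecoderTorch(
    image_vocab_size = 31,    # toy numbers throughout
    image_token_count = 16,
    sample_token_count = 4,
    embed_count = 8,
    attention_head_count = 2,
    glu_embed_count = 16,
    layer_count = 1,
    batch_count = 2,          # conditioned row and unconditioned row
    start_token = 31,
    is_verbose = False
)
text_tokens = torch.tensor([[0, 5, 7, 2, 1, 1], [0, 2, 1, 1, 1, 1]])  # 1 = pad
encoder_state = torch.randn(2, 6, 8)  # stands in for DalleBartEncoderTorch output
with torch.no_grad():
    image_tokens = decoder.forward(text_tokens, encoder_state)
print(image_tokens.shape)  # torch.Size([4])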

min_dalle/models/dalle_bart_encoder_flax.py (new file, 143 lines)
@@ -0,0 +1,143 @@
from functools import partial
import jax
from jax import lax, numpy as jnp
from flax import linen as nn
from typing import Tuple


class GLUFlax(nn.Module):
    count_in_out: int
    count_middle: int

    def setup(self):
        self.gelu = partial(nn.gelu, approximate=False)
        self.ln0 = nn.LayerNorm(use_scale=False)
        self.ln1 = nn.LayerNorm(use_scale=False)
        self.fc0 = nn.Dense(self.count_middle, use_bias=False)
        self.fc1 = nn.Dense(self.count_middle, use_bias=False)
        self.fc2 = nn.Dense(self.count_in_out, use_bias=False)

    @nn.compact
    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
        z = self.ln0(z)
        z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z))
        z = self.fc2(z)
        return z


class AttentionFlax(nn.Module):
    head_count: int
    embed_count: int

    def setup(self):
        self.q_proj = nn.Dense(self.embed_count, use_bias=False)
        self.k_proj = nn.Dense(self.embed_count, use_bias=False)
        self.v_proj = nn.Dense(self.embed_count, use_bias=False)
        self.out_proj = nn.Dense(self.embed_count, use_bias=False)

    def forward(self,
        keys: jnp.ndarray,
        values: jnp.ndarray,
        queries: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        # masked positions get a -inf bias so they vanish under softmax
        attention_bias: jnp.ndarray = lax.select(
            attention_mask,
            jnp.full(attention_mask.shape, 0.0),
            jnp.full(attention_mask.shape, -jnp.inf),
        )
        attention_weights: jnp.ndarray = jnp.einsum(
            'bqhd,bkhd->bhqk',
            queries,
            keys
        )
        attention_weights += attention_bias[:, None, None, :]
        attention_weights = jax.nn.softmax(attention_weights)
        attention_output: jnp.ndarray = jnp.einsum(
            "bhqk,bkhd->bqhd",
            attention_weights,
            values
        )
        shape = attention_output.shape[:2] + (self.embed_count,)
        attention_output = attention_output.reshape(shape)
        attention_output = self.out_proj(attention_output)
        return attention_output


class EncoderSelfAttentionFlax(AttentionFlax):
    def __call__(
        self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:
        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
        keys = self.k_proj(encoder_state).reshape(shape_split)
        values = self.v_proj(encoder_state).reshape(shape_split)
        queries = self.q_proj(encoder_state).reshape(shape_split)
        queries /= queries.shape[-1] ** 0.5
        return self.forward(keys, values, queries, attention_mask)


class DalleBartEncoderLayerFlax(nn.Module):
    attention_head_count: int
    embed_count: int
    glu_embed_count: int

    def setup(self):
        self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
        self.self_attn = EncoderSelfAttentionFlax(
            self.attention_head_count,
            self.embed_count
        )
        self.self_attn_layer_norm = nn.LayerNorm()
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> Tuple[jnp.ndarray, None]:
        residual = encoder_state
        encoder_state = self.pre_self_attn_layer_norm(encoder_state)
        encoder_state = self.self_attn(encoder_state, attention_mask)
        encoder_state = self.self_attn_layer_norm(encoder_state)
        encoder_state = residual + encoder_state
        residual = encoder_state
        encoder_state = self.glu(encoder_state)
        encoder_state = residual + encoder_state
        # (carry, output) pair as required by nn.scan
        return encoder_state, None


class DalleBartEncoderFlax(nn.Module):
    attention_head_count: int
    embed_count: int
    glu_embed_count: int
    text_token_count: int
    text_vocab_count: int
    layer_count: int

    def setup(self):
        self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count)
        self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
        # one scanned layer whose parameters are stacked along axis 0
        self.layers = nn.scan(
            DalleBartEncoderLayerFlax,
            variable_axes = { "params": 0, "cache": 0 },
            split_rngs = { "params": True },
            in_axes = nn.broadcast,
            length = self.layer_count
        )(
            self.attention_head_count,
            self.embed_count,
            self.glu_embed_count,
            name="FlaxBartEncoderLayers"
        )
        self.layernorm_embedding = nn.LayerNorm()
        self.final_ln = nn.LayerNorm(use_scale=False)

    def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray:
        batch_count, token_count = text_tokens.shape
        pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1))
        # token id 1 is the pad token
        attention_mask = jnp.not_equal(text_tokens, 1)
        encoder_state = (
            self.embed_tokens(text_tokens) +
            self.embed_positions(pose_tokens)
        )
        encoder_state = self.layernorm_embedding(encoder_state)
        encoder_state, _ = self.layers(encoder_state, attention_mask)
        encoder_state = self.final_ln(encoder_state)
        return encoder_state
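
A minimal init/apply sketch for the Flax encoder, again with toy sizes (every number here is made up; token id 1 acts as padding):

import jax
from jax import numpy as jnp
from dalle_bart_encoder_flax import DalleBartEncoderFlax  # hypothetical flat import

encoder = DalleBartEncoderFlax(
    attention_head_count = 2,
    embed_count = 8,
    glu_embed_count = 16,
    text_token_count = 6,
    text_vocab_count = 100,
    layer_count = 2
)
text_tokens = jnp.array([[0, 5, 7, 2, 1, 1], [0, 2, 1, 1, 1, 1]])
variables = encoder.init(jax.random.PRNGKey(0), text_tokens)
encoder_state = encoder.apply(variables, text_tokens)  # shape (2, 6, 8)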

min_dalle/models/dalle_bart_encoder_torch.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor


class GLUTorch(nn.Module):
    def __init__(self, count_in_out, count_middle):
        super().__init__()
        self.gelu = nn.GELU()
        self.ln0 = nn.LayerNorm(count_in_out)
        self.ln1 = nn.LayerNorm(count_middle)
        self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)
        self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)
        self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)

    def forward(self, z: FloatTensor) -> FloatTensor:
        z = self.ln0.forward(z)
        w = self.fc0.forward(z)
        w = self.gelu.forward(w)
        v = self.fc1.forward(z)
        z = self.ln1.forward(w * v)
        z = self.fc2.forward(z)
        return z


class AttentionTorch(nn.Module):
    def __init__(self, head_count: int, embed_count: int):
        super().__init__()
        self.head_count = head_count
        self.embed_count = embed_count
        self.head_dim = embed_count // head_count

        # 1x1 convolutions play the role of the usual linear projections;
        # states are kept in (batch, embed, 1, token) layout
        self.k_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False)
        self.v_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False)
        self.q_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False)
        self.out_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False)

    def forward(self,
        keys: FloatTensor,
        values: FloatTensor,
        queries: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        batch_count = keys.shape[0]

        # b(hc)1k -> bkhc
        keys = keys.transpose(1, 3)
        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))

        # b(hc)1q -> bchq (likewise for values with k in place of q)
        shape = (batch_count, self.head_count, self.head_dim, -1)
        values = values.reshape(shape)
        values = values.transpose(1, 2)
        queries = queries.reshape(shape)
        queries = queries.transpose(1, 2)

        attention_bias = torch.where(
            attention_mask,
            torch.zeros([1, 1]),
            torch.ones([1, 1]) * (-torch.inf),
        )
        attention_weights: FloatTensor = torch.einsum(
            'bchq,bkhc->bkhq',
            queries / self.head_dim ** 0.5,
            keys
        )
        attention_weights += attention_bias[:, :, None, None]
        attention_weights = torch.softmax(attention_weights, 1)
        hidden_state: FloatTensor = torch.einsum(
            "bkhq,bchk->bchq",
            attention_weights,
            values
        )
        # bchq -> b(hc)1q
        hidden_state = hidden_state.transpose(1, 2)
        hidden_state = hidden_state.reshape(batch_count, self.embed_count, 1, -1)
        hidden_state = self.out_proj.forward(hidden_state)
        return hidden_state


class EncoderSelfAttentionTorch(AttentionTorch):
    def forward(
        self,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        # (batch, token, embed) -> (batch, embed, 1, token)
        encoder_state = encoder_state.transpose(1, 2).unsqueeze(2)
        keys = self.k_proj.forward(encoder_state)
        values = self.v_proj.forward(encoder_state)
        queries = self.q_proj.forward(encoder_state)
        return super().forward(keys, values, queries, attention_mask)


class EncoderLayerTorch(nn.Module):
    def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
        super().__init__()
        self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
        self.self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.glu = GLUTorch(embed_count, glu_embed_count)

    def forward(
        self,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        residual = encoder_state
        encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state)
        encoder_state = self.self_attn.forward(encoder_state, attention_mask)
        encoder_state = encoder_state.transpose(1, 3).squeeze(2)
        encoder_state = self.self_attn_layer_norm.forward(encoder_state)
        encoder_state = residual + encoder_state
        residual = encoder_state
        encoder_state = self.glu.forward(encoder_state)
        encoder_state = residual + encoder_state
        return encoder_state


class DalleBartEncoderTorch(nn.Module):
    def __init__(self,
        layer_count: int,
        embed_count: int,
        attention_head_count: int,
        text_vocab_count: int,
        text_token_count: int,
        glu_embed_count: int
    ):
        super().__init__()
        self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
        self.embed_positions = nn.Embedding(text_token_count, embed_count)
        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
            EncoderLayerTorch(
                embed_count = embed_count,
                head_count = attention_head_count,
                glu_embed_count = glu_embed_count
            )
            for _ in range(layer_count)
        ])
        self.layernorm_embedding = nn.LayerNorm(embed_count)
        self.final_ln = nn.LayerNorm(embed_count)

    def forward(self, text_tokens: LongTensor) -> FloatTensor:
        # token id 1 is the pad token
        attention_mask = text_tokens.not_equal(1)
        batch_count, token_count = text_tokens.shape
        pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
        encoder_state = (
            self.embed_tokens.forward(text_tokens) +
            self.embed_positions.forward(pose_tokens)
        )
        encoder_state = self.layernorm_embedding.forward(encoder_state)
        for layer in self.layers:
            encoder_state = layer.forward(encoder_state, attention_mask)
        encoder_state = self.final_ln.forward(encoder_state)
        return encoder_state
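
The torch encoder is driven the same way; a sketch under the same toy sizes:

import torch
from dalle_bart_encoder_torch import DalleBartEncoderTorch  # hypothetical flat import

encoder = DalleBartEncoderTorch(
    layer_count = 2,
    embed_count = 8,
    attention_head_count = 2,
    text_vocab_count = 100,
    text_token_count = 6,
    glu_embed_count = 16
)
text_tokens = torch.tensor([[0, 5, 7, 2, 1, 1], [0, 2, 1, 1, 1, 1]])  # 1 = pad
with torch.no_grad():
    encoder_state = encoder.forward(text_tokens)
print(encoder_state.shape)  # torch.Size([2, 6, 8])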

min_dalle/models/vqgan_detokenizer.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding

batch_size: int = 1


class ResnetBlock(Module):
    def __init__(self, log2_count_in: int, log2_count_out: int):
        super().__init__()
        m, n = 2 ** log2_count_in, 2 ** log2_count_out
        self.is_middle = m == n
        self.norm1 = GroupNorm(2 ** 5, m)
        self.conv1 = Conv2d(m, n, 3, padding=1)
        self.norm2 = GroupNorm(2 ** 5, n)
        self.conv2 = Conv2d(n, n, 3, padding=1)
        if not self.is_middle:
            self.nin_shortcut = Conv2d(m, n, 1)

    def forward(self, x: Tensor) -> Tensor:
        h = x
        h = self.norm1.forward(h)
        h *= torch.sigmoid(h)  # swish activation
        h = self.conv1.forward(h)
        h = self.norm2.forward(h)
        h *= torch.sigmoid(h)
        h = self.conv2(h)
        if not self.is_middle:
            x = self.nin_shortcut.forward(x)
        return x + h


class AttentionBlock(Module):
    def __init__(self):
        super().__init__()
        n = 2 ** 9
        self.norm = GroupNorm(2 ** 5, n)
        self.q = Conv2d(n, n, 1)
        self.k = Conv2d(n, n, 1)
        self.v = Conv2d(n, n, 1)
        self.proj_out = Conv2d(n, n, 1)

    def forward(self, x: Tensor) -> Tensor:
        n = 2 ** 9
        h = x
        h = self.norm(h)
        q = self.q.forward(h)
        k = self.k.forward(h)
        v = self.v.forward(h)
        q = q.reshape(batch_size, n, 2 ** 8)
        q = q.permute(0, 2, 1)
        k = k.reshape(batch_size, n, 2 ** 8)
        w = torch.bmm(q, k)
        w /= n ** 0.5
        w = torch.softmax(w, dim=2)
        v = v.reshape(batch_size, n, 2 ** 8)
        w = w.permute(0, 2, 1)
        h = torch.bmm(v, w)
        h = h.reshape(batch_size, n, 2 ** 4, 2 ** 4)
        h = self.proj_out.forward(h)
        return x + h


class MiddleLayer(Module):
    def __init__(self):
        super().__init__()
        self.block_1 = ResnetBlock(9, 9)
        self.attn_1 = AttentionBlock()
        self.block_2 = ResnetBlock(9, 9)

    def forward(self, h: Tensor) -> Tensor:
        h = self.block_1.forward(h)
        h = self.attn_1.forward(h)
        h = self.block_2.forward(h)
        return h


class Upsample(Module):
    def __init__(self, log2_count):
        super().__init__()
        n = 2 ** log2_count
        self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
        self.conv = Conv2d(n, n, 3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample.forward(x)
        x = self.conv.forward(x)
        return x


class UpsampleBlock(Module):
    def __init__(
        self,
        log2_count_in: int,
        log2_count_out: int,
        has_attention: bool,
        has_upsample: bool
    ):
        super().__init__()
        self.has_attention = has_attention
        self.has_upsample = has_upsample
        self.block = ModuleList([
            ResnetBlock(log2_count_in, log2_count_out),
            ResnetBlock(log2_count_out, log2_count_out),
            ResnetBlock(log2_count_out, log2_count_out)
        ])
        if has_attention:
            self.attn = ModuleList([
                AttentionBlock(),
                AttentionBlock(),
                AttentionBlock()
            ])
        else:
            self.attn = ModuleList()

        if has_upsample:
            self.upsample = Upsample(log2_count_out)

    def forward(self, h: Tensor) -> Tensor:
        for j in range(3):
            h = self.block[j].forward(h)
            if self.has_attention:
                h = self.attn[j].forward(h)
        if self.has_upsample:
            h = self.upsample.forward(h)
        return h


class Decoder(Module):
    def __init__(self):
        super().__init__()

        self.conv_in = Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
        self.mid = MiddleLayer()

        self.up = ModuleList([
            UpsampleBlock(7, 7, False, False),
            UpsampleBlock(8, 7, False, True),
            UpsampleBlock(8, 8, False, True),
            UpsampleBlock(9, 8, False, True),
            UpsampleBlock(9, 9, True, True)
        ])

        self.norm_out = GroupNorm(2 ** 5, 2 ** 7)
        self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        z = self.conv_in.forward(z)
        z = self.mid.forward(z)

        # upsample from 16x16 to 256x256 spatial resolution
        for i in reversed(range(5)):
            z = self.up[i].forward(z)

        z = self.norm_out.forward(z)
        z *= torch.sigmoid(z)
        z = self.conv_out.forward(z)
        return z


class VQGanDetokenizer(Module):
    def __init__(self):
        super().__init__()
        m, n = 2 ** 14, 2 ** 8  # codebook size, embedding dimension
        self.embedding = Embedding(m, n)
        self.post_quant_conv = Conv2d(n, n, 1)
        self.decoder = Decoder()

    def forward(self, z: Tensor) -> Tensor:
        z = self.embedding.forward(z)
        z = z.view((batch_size, 2 ** 4, 2 ** 4, 2 ** 8))
        z = z.permute(0, 3, 1, 2).contiguous()
        z = self.post_quant_conv.forward(z)
        z = self.decoder.forward(z)
        z = z.permute(0, 2, 3, 1)
        # z = torch.concat((
        #     torch.concat((z[0], z[1]), axis=1),
        #     torch.concat((z[2], z[3]), axis=1)
        # ), axis=0)
        z = z.clip(0.0, 1.0) * 255
        return z[0]
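
A minimal sketch of running the detokenizer on a full 16x16 grid of image tokens (random ids here; in the real pipeline they come from the decoder, which must be run with sample_token_count = 256 to fill the grid):

import torch
from vqgan_detokenizer import VQGanDetokenizer  # hypothetical flat import

detokenizer = VQGanDetokenizer()  # randomly initialized; real use loads checkpoint weights
image_tokens = torch.randint(0, 2 ** 14, (1, 2 ** 8))  # 256 token ids for one image
with torch.no_grad():
    image = detokenizer.forward(image_tokens)
print(image.shape)  # torch.Size([256, 256, 3]), values in [0, 255]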