moved flax model and conversion code to separate repository

main
Brett Kuprel 2 years ago
parent febd18df77
commit 07ce93d5f8
  1. 10
      README.md
  2. 27
      image_from_text.py
  3. 136
      min_dalle/load_params.py
  4. 32
      min_dalle/min_dalle_base.py
  5. 87
      min_dalle/min_dalle_flax.py
  6. 33
      min_dalle/min_dalle_torch.py
  7. 247
      min_dalle/models/dalle_bart_decoder_flax.py
  8. 151
      min_dalle/models/dalle_bart_encoder_flax.py
  9. 0
      replicate/cog.yaml
  10. 0
      replicate/predict.py
  11. 3
      requirements_flax.txt
  12. 29
      setup.sh
  13. 14
      setup_flax.sh

10
README.md vendored

@ -3,21 +3,19 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)  
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch.
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in PyTorch. It has been stripped to the bare essentials necessary for doing inference. The only third party dependencies are numpy and torch.
It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
The flax model, and the code for coverting it to torch, has been moved [here](https://github.com/kuprel/min-dalle-flax).
### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle/tree/main).
The flax models can be manually downloaded here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize `MinDalleTorch`, then call `generate_image` with some text and a seed. See the colab for an example.
### Examples

@ -3,15 +3,11 @@ import os
from PIL import Image
from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
parser.add_argument('--no-mega', dest='mega', action='store_false')
parser.set_defaults(mega=False)
parser.add_argument('--torch', action='store_true')
parser.add_argument('--no-torch', dest='torch', action='store_false')
parser.set_defaults(torch=True)
parser.add_argument('--text', type=str, default='cat')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--image_path', type=str, default='generated')
@ -37,7 +33,6 @@ def save_image(image: Image.Image, path: str):
def generate_image(
is_torch: bool,
is_mega: bool,
text: str,
seed: int,
@ -45,29 +40,21 @@ def generate_image(
token_count: int
):
is_reusable = False
if is_torch:
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
if token_count < 256:
image_tokens = image_generator.generate_image_tokens(text, seed)
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)
model = MinDalleTorch(is_mega, is_reusable, token_count)
if token_count < 256:
image_tokens = model.generate_image_tokens(text, seed)
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
else:
image_generator = MinDalleFlax(is_mega, is_reusable)
image = image_generator.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
image = model.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__':
args = parser.parse_args()
print(args)
generate_image(
is_torch=args.torch,
is_mega=args.mega,
text=args.text,
seed=args.seed,

@ -1,136 +0,0 @@
import os
import numpy
from typing import Dict
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore
import torch
torch.set_grad_enabled(False)
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in list(P.keys()):
j = i
if 'up' in i or 'down' in i:
j = i.replace('_', '.')
j = j.replace('proj.out', 'proj_out')
j = j.replace('nin.short', 'nin_short')
if 'bias' in i:
P[j] = P.pop(i)
elif 'scale' in i:
j = j.replace('scale', 'weight')
P[j] = P.pop(i)
elif 'kernel' in i:
j = j.replace('kernel', 'weight')
P[j] = P.pop(i).transpose(3, 2, 0, 1)
for i in P:
P[i] = torch.tensor(P[i])
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
for i in list(P):
if i.split('.')[0] in ['encoder', 'quant_conv']:
P.pop(i)
return P
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
params = msgpack_restore(f.read())
for codec in ['encoder', 'decoder']:
k = 'FlaxBart{}Layers'.format(codec.title())
P: dict = params['model'][codec]['layers'][k]
P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0')
P['self_attn_layer_norm'] = P.pop('LayerNorm_1')
P['self_attn'] = P.pop('FlaxBartAttention_0')
if codec == 'decoder':
P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2')
P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3')
P['encoder_attn'] = P.pop('FlaxBartAttention_1')
P['glu']: dict = P.pop('GLU_0')
P['glu']['ln0'] = P['glu'].pop('LayerNorm_0')
P['glu']['ln1'] = P['glu'].pop('LayerNorm_1')
P['glu']['fc0'] = P['glu'].pop('Dense_0')
P['glu']['fc1'] = P['glu'].pop('Dense_1')
P['glu']['fc2'] = P['glu'].pop('Dense_2')
for codec in ['encoder', 'decoder']:
layers_params = params['model'][codec].pop('layers')
params['model'][codec] = {
**params['model'][codec],
**layers_params
}
model_params = params.pop('model')
params = {**params, **model_params}
params['decoder']['lm_head'] = params.pop('lm_head')
return params
def convert_dalle_bart_torch_from_flax_params(
params: dict,
layer_count: int,
is_encoder: bool
) -> dict:
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in P:
P[i] = torch.tensor(P[i]).to(torch.float16)
for i in list(P):
if 'kernel' in i:
j = i.replace('kernel', 'weight')
P[j] = P.pop(i).transpose(-1, -2)
elif 'scale' in i:
j = i.replace('scale', 'weight')
P[j] = P.pop(i)
for i in list(P):
j = 'FlaxBart{}Layers'.format('Encoder' if is_encoder else 'Decoder')
if j in i:
for l in range(layer_count):
k = i.replace(j, 'layers.' + str(l))
P[k] = P[i][l]
P.pop(i)
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
return P
def convert_and_save_torch_params(is_mega: bool, model_path: str):
print("converting params to torch")
layer_count = 24 if is_mega else 12
flax_params = load_dalle_bart_flax_params(model_path)
encoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['encoder'],
layer_count=layer_count,
is_encoder=True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['decoder'],
layer_count=layer_count,
is_encoder=False
)
for i in decoder_params:
decoder_params[i] = decoder_params[i].to(torch.float16)
for i in encoder_params:
encoder_params[i] = encoder_params[i].to(torch.float16)
detoker_params = load_vqgan_torch_params('./pretrained/vqgan')
detoker_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))
torch.save(detoker_params, detoker_path)

@ -1,32 +0,0 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
class MinDalleBase:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(self.model_path))
vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt')
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.tokenizer = TextTokenizer(vocab, merges)
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
return text_tokens

@ -1,87 +0,0 @@
import jax
import numpy
from PIL import Image
import torch
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
from .models.vqgan_detokenizer import VQGanDetokenizer
from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params
class MinDalleFlax(MinDalleBase):
def __init__(self, is_mega: bool, is_reusable: bool = True):
super().__init__(is_mega)
self.is_reusable = is_reusable
print("initializing MinDalleFlax")
self.model_params = load_dalle_bart_flax_params(self.model_path)
if is_reusable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()
def init_encoder(self):
print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = 64,
text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = 24 if self.is_mega else 12
).bind({'params': self.model_params.pop('encoder')})
def init_decoder(self):
print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax(
image_token_count = 256,
image_vocab_count = 16415 if self.is_mega else 16384,
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
layer_count = 24 if self.is_mega else 12,
start_token = 16415 if self.is_mega else 16384
)
def init_detokenizer(self):
print("initializing VQGanDetokenizer")
params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer = VQGanDetokenizer()
self.detokenizer.load_state_dict(params)
del params
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)
if not self.is_reusable: self.init_encoder()
print("encoding text tokens")
encoder_state = self.encoder(text_tokens)
if not self.is_reusable: del self.encoder
if not self.is_reusable:
self.init_decoder()
params = self.model_params.pop('decoder')
else:
params = self.model_params['decoder']
print("sampling image tokens")
image_tokens = self.decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
params
)
if not self.is_reusable: del self.decoder
image_tokens = torch.tensor(numpy.array(image_tokens))
if not self.is_reusable: self.init_detokenizer()
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
if not self.is_reusable: del self.detokenizer
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

@ -1,18 +1,20 @@
import os
from PIL import Image
from typing import Dict
import numpy
from torch import LongTensor
import torch
import json
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .min_dalle_base import MinDalleBase
from .text_tokenizer import TextTokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalleTorch(MinDalleBase):
class MinDalleTorch:
def __init__(
self,
is_mega: bool,
@ -20,7 +22,20 @@ class MinDalleTorch(MinDalleBase):
token_count: int = 256
):
print("initializing MinDalleTorch")
super().__init__(is_mega)
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(self.model_path))
vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt')
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.tokenizer = TextTokenizer(vocab, merges)
self.is_reusable = is_reusable
self.token_count = token_count
@ -76,7 +91,17 @@ class MinDalleTorch(MinDalleBase):
self.detokenizer.load_state_dict(params)
del params
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
return text_tokens
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)

@ -1,247 +0,0 @@
import jax, flax
from jax import lax, numpy as jnp
from flax import linen as nn
from typing import Tuple
from .dalle_bart_encoder_flax import GLUFlax, AttentionFlax
class DecoderCrossAttentionFlax(AttentionFlax):
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray,
) -> jnp.ndarray:
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(decoder_state)
return self.forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(
self,
decoder_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
state_index: tuple
) -> Tuple[jnp.ndarray, jnp.ndarray]:
keys = self.k_proj(decoder_state)
values = self.v_proj(decoder_state)
queries = self.q_proj(decoder_state)
attention_state = lax.dynamic_update_slice(
attention_state,
jnp.concatenate([keys, values]).astype(jnp.float32),
state_index
)
batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count]
values = attention_state[batch_count:]
decoder_state = self.forward(
keys,
values,
queries,
attention_mask
).astype(decoder_state.dtype)
return decoder_state, attention_state
class DalleBartDecoderLayerFlax(nn.Module):
image_token_count: int
attention_head_count: int
embed_count: int
glu_embed_count: int
def setup(self):
self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
self.self_attn = DecoderSelfAttentionFlax(
self.attention_head_count,
self.embed_count
)
self.self_attn_layer_norm = nn.LayerNorm()
self.pre_encoder_attn_layer_norm = nn.LayerNorm(use_scale=False)
self.encoder_attn = DecoderCrossAttentionFlax(
self.attention_head_count,
self.embed_count,
)
self.encoder_attn_layer_norm = nn.LayerNorm()
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
token_index: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
# Self Attention
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
self_attention_mask = jnp.tile(
jnp.arange(self.image_token_count) < token_index + 1,
(decoder_state.shape[0], 1)
)
decoder_state, attention_state = self.self_attn(
decoder_state,
attention_state,
self_attention_mask,
(0, token_index, 0)
)
decoder_state = self.self_attn_layer_norm(decoder_state)
decoder_state = residual + decoder_state
# Cross Attention
residual = decoder_state
decoder_state = self.pre_encoder_attn_layer_norm(decoder_state)
decoder_state = self.encoder_attn(
decoder_state,
encoder_state,
attention_mask
)
decoder_state = self.encoder_attn_layer_norm(decoder_state)
decoder_state = residual + decoder_state
# Feed forward
residual = decoder_state
decoder_state = self.glu(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, attention_state
@flax.struct.dataclass
class SampleState:
prev_token: jnp.ndarray
prng_key: jnp.ndarray
attention_state: jnp.ndarray
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
return (1 - a) * logits[0, -1] + a * logits[1, -1]
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
top_logits, _ = lax.top_k(logits, k)
suppressed = -jnp.inf * jnp.ones_like(logits)
return lax.select(logits < top_logits[-1], suppressed, logits)
class DalleBartDecoderFlax(nn.Module):
image_token_count: int
image_vocab_count: int
attention_head_count: int
embed_count: int
glu_embed_count: int
layer_count: int
start_token: int
def setup(self):
self.embed_tokens = nn.Embed(
self.image_vocab_count + 1,
self.embed_count
)
self.embed_positions = nn.Embed(
self.image_token_count,
self.embed_count
)
self.layers = nn.scan(
DalleBartDecoderLayerFlax,
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
out_axes = 0,
length=self.layer_count,
)(
self.image_token_count,
self.attention_head_count,
self.embed_count,
self.glu_embed_count,
name="FlaxBartDecoderLayers"
)
self.layernorm_embedding = nn.LayerNorm()
self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(
self,
encoder_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
prev_token: int,
token_index: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
batch_count = encoder_state.shape[0]
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
decoder_state = self.embed_tokens(prev_token * ones)
decoder_state += self.embed_positions(token_index * ones)
decoder_state = self.layernorm_embedding(decoder_state)
decoder_state, attention_state = self.layers(
decoder_state,
encoder_state,
attention_state,
attention_mask,
token_index
)
decoder_state = self.final_ln(decoder_state)
decoder_state = self.lm_head(decoder_state)
return decoder_state, attention_state
def sample_image_tokens(
self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey,
params: dict
) -> jnp.ndarray:
attention_mask = jnp.not_equal(text_tokens, 1)
def sample_next_image_token(
state: SampleState,
token_index: int
) -> Tuple[SampleState, jnp.ndarray]:
logits, attention_state = self.apply(
{ 'params': params },
encoder_state = encoder_state,
attention_state = state.attention_state,
attention_mask = attention_mask,
prev_token = state.prev_token,
token_index = token_index
)
logits = super_conditioned(logits, 10.0)
logits = keep_top_k(logits, k=50)
prng_key, prng_key_next = jax.random.split(state.prng_key)
next_token = jax.random.categorical(prng_key, logits, axis=-1)
state = SampleState(
prev_token = next_token,
prng_key = prng_key_next,
attention_state = attention_state
)
return state, next_token
batch_count = encoder_state.shape[0]
attention_state_shape = (
self.layer_count,
batch_count * 2,
self.image_token_count,
self.embed_count
)
initial_state = SampleState(
prev_token = self.start_token,
prng_key = prng_key,
attention_state = jnp.zeros(attention_state_shape)
)
_, image_tokens = lax.scan(
sample_next_image_token,
initial_state,
jnp.arange(self.image_token_count)
)
return image_tokens

@ -1,151 +0,0 @@
from functools import partial
import jax
from jax import lax, numpy as jnp
from flax import linen as nn
class GLUFlax(nn.Module):
count_in_out: int
count_middle: int
def setup(self):
self.gelu = partial(nn.gelu, approximate=False)
self.ln0 = nn.LayerNorm(use_scale=False)
self.ln1 = nn.LayerNorm(use_scale=False)
self.fc0 = nn.Dense(self.count_middle, use_bias=False)
self.fc1 = nn.Dense(self.count_middle, use_bias=False)
self.fc2 = nn.Dense(self.count_in_out, use_bias=False)
@nn.compact
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
z = self.ln0(z)
z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z))
z = self.fc2(z)
return z
class AttentionFlax(nn.Module):
head_count: int
embed_count: int
def setup(self):
self.q_proj = nn.Dense(self.embed_count, use_bias=False)
self.k_proj = nn.Dense(self.embed_count, use_bias=False)
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(
self,
keys: jnp.ndarray,
values: jnp.ndarray,
queries: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias: jnp.ndarray = lax.select(
attention_mask,
jnp.full(attention_mask.shape, 0.0),
jnp.full(attention_mask.shape, -jnp.inf),
)
attention_weights: jnp.ndarray = jnp.einsum(
'bqhd,bkhd->bhqk',
queries,
keys
)
attention_weights += attention_bias[:, None, None, :]
attention_weights = jax.nn.softmax(attention_weights)
attention_output: jnp.ndarray = jnp.einsum(
"bhqk,bkhd->bqhd",
attention_weights,
values
)
shape = attention_output.shape[:2] + (self.embed_count,)
attention_output = attention_output.reshape(shape)
attention_output = self.out_proj(attention_output)
return attention_output
class EncoderSelfAttentionFlax(AttentionFlax):
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(encoder_state)
return self.forward(keys, values, queries, attention_mask)
class DalleBartEncoderLayerFlax(nn.Module):
attention_head_count: int
embed_count: int
glu_embed_count: int
def setup(self):
self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
self.self_attn = EncoderSelfAttentionFlax(
self.attention_head_count,
self.embed_count
)
self.self_attn_layer_norm = nn.LayerNorm()
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
residual = encoder_state
encoder_state = self.pre_self_attn_layer_norm(encoder_state)
encoder_state = self.self_attn(encoder_state, attention_mask)
encoder_state = self.self_attn_layer_norm(encoder_state)
encoder_state = residual + encoder_state
residual = encoder_state
encoder_state = self.glu(encoder_state)
encoder_state = residual + encoder_state
return encoder_state, None
class DalleBartEncoderFlax(nn.Module):
attention_head_count: int
embed_count: int
glu_embed_count: int
text_token_count: int
text_vocab_count: int
layer_count: int
def setup(self):
self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count)
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
self.layers = nn.scan(
DalleBartEncoderLayerFlax,
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = nn.broadcast,
length = self.layer_count
)(
self.attention_head_count,
self.embed_count,
self.glu_embed_count,
name="FlaxBartEncoderLayers"
)
self.layernorm_embedding = nn.LayerNorm()
self.final_ln = nn.LayerNorm(use_scale=False)
def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray:
batch_count, token_count = text_tokens.shape
pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1))
attention_mask = jnp.not_equal(text_tokens, 1)
encoder_state = (
self.embed_tokens(text_tokens) +
self.embed_positions(pose_tokens)
)
encoder_state = self.layernorm_embedding(encoder_state)
encoder_state, _ = self.layers(encoder_state, attention_mask)
encoder_state = self.final_ln(encoder_state)
return encoder_state

@ -1,3 +0,0 @@
flax
torch
wandb

29
setup.sh vendored

@ -4,17 +4,22 @@ set -e
pip3 install -r requirements.txt
mkdir -p ./pretrained/dalle_bart_mega/
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
repo_path="https://huggingface.co/kuprel/min-dalle/resolve/main"
mkdir -p ./pretrained/dalle_bart_mini/
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt
mini_path="./pretrained/dalle_bart_mini"
mega_path="./pretrained/dalle_bart_mega"
vqgan_path="./pretrained/vqgan"
mkdir -p ./pretrained/vqgan/
curl https://huggingface.co/kuprel/min-dalle/resolve/main/detoker.pt -L --output ./pretrained/vqgan/detoker.pt
mkdir -p ${vqgan_path}
mkdir -p ${mini_path}
mkdir -p ${mega_path}
curl ${repo_path}/detoker.pt -L --output ${vqgan_path}/detoker.pt
curl ${repo_path}/vocab_mini.json -L --output ${mini_path}/vocab.json
curl ${repo_path}/merges_mini.txt -L --output ${mini_path}/merges.txt
curl ${repo_path}/encoder_mini.pt -L --output ${mini_path}/encoder.pt
curl ${repo_path}/decoder_mini.pt -L --output ${mini_path}/decoder.pt
curl ${repo_path}/vocab.json -L --output ${mega_path}/vocab.json
curl ${repo_path}/merges.txt -L --output ${mega_path}/merges.txt
curl ${repo_path}/encoder.pt -L --output ${mega_path}/encoder.pt
curl ${repo_path}/decoder.pt -L --output ${mega_path}/decoder.pt

14
setup_flax.sh vendored

@ -1,14 +0,0 @@
#!/bin/bash
set -e
pip3 install -r requirements_flax.txt
# download vqgan
mkdir -p pretrained/vqgan
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
# download dalle-mini and dalle-mega
python3 -m wandb login --anonymously
python3 -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
python3 -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
Loading…
Cancel
Save