From 07ce93d5f89ff9b467659fc99e4361e63a68a8db Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Fri, 1 Jul 2022 14:06:50 -0400 Subject: [PATCH] moved flax model and conversion code to separate repository --- README.md | 10 +- image_from_text.py | 27 +-- min_dalle/load_params.py | 136 ----------- min_dalle/min_dalle_base.py | 32 --- min_dalle/min_dalle_flax.py | 87 ------- min_dalle/min_dalle_torch.py | 33 ++- min_dalle/models/dalle_bart_decoder_flax.py | 247 -------------------- min_dalle/models/dalle_bart_encoder_flax.py | 151 ------------ cog.yaml => replicate/cog.yaml | 0 predict.py => replicate/predict.py | 0 requirements_flax.txt | 3 - setup.sh | 29 ++- setup_flax.sh | 14 -- 13 files changed, 57 insertions(+), 712 deletions(-) delete mode 100644 min_dalle/load_params.py delete mode 100644 min_dalle/min_dalle_base.py delete mode 100644 min_dalle/min_dalle_flax.py delete mode 100644 min_dalle/models/dalle_bart_decoder_flax.py delete mode 100644 min_dalle/models/dalle_bart_encoder_flax.py rename cog.yaml => replicate/cog.yaml (100%) rename predict.py => replicate/predict.py (100%) delete mode 100644 requirements_flax.txt delete mode 100644 setup_flax.sh diff --git a/README.md b/README.md index 158b2d8..dcad6b9 100644 --- a/README.md +++ b/README.md @@ -3,21 +3,19 @@ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)   [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle) -This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. +This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in PyTorch. 
It has been stripped to the bare essentials necessary for doing inference. The only third party dependencies are numpy and torch. It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab +The flax model, and the code for converting it to torch, have been moved [here](https://github.com/kuprel/min-dalle-flax). + ### Setup Run `sh setup.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle/tree/main). -The flax models can be manually downloaded here: -[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), -[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), -[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) ### Usage -Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example. +Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize `MinDalleTorch`, then call `generate_image` with some text and a seed. See the colab for an example. 
### Examples diff --git a/image_from_text.py b/image_from_text.py index f300c9c..183dc5e 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -3,15 +3,11 @@ import os from PIL import Image from min_dalle.min_dalle_torch import MinDalleTorch -from min_dalle.min_dalle_flax import MinDalleFlax parser = argparse.ArgumentParser() parser.add_argument('--mega', action='store_true') parser.add_argument('--no-mega', dest='mega', action='store_false') parser.set_defaults(mega=False) -parser.add_argument('--torch', action='store_true') -parser.add_argument('--no-torch', dest='torch', action='store_false') -parser.set_defaults(torch=True) parser.add_argument('--text', type=str, default='cat') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--image_path', type=str, default='generated') @@ -37,7 +33,6 @@ def save_image(image: Image.Image, path: str): def generate_image( - is_torch: bool, is_mega: bool, text: str, seed: int, @@ -45,29 +40,21 @@ def generate_image( token_count: int ): is_reusable = False - if is_torch: - image_generator = MinDalleTorch(is_mega, is_reusable, token_count) - - if token_count < 256: - image_tokens = image_generator.generate_image_tokens(text, seed) - print('image tokens', list(image_tokens.to('cpu').detach().numpy())) - return - else: - image = image_generator.generate_image(text, seed) + model = MinDalleTorch(is_mega, is_reusable, token_count) + if token_count < 256: + image_tokens = model.generate_image_tokens(text, seed) + print('image tokens', list(image_tokens.to('cpu').detach().numpy())) else: - image_generator = MinDalleFlax(is_mega, is_reusable) - image = image_generator.generate_image(text, seed) - - save_image(image, image_path) - print(ascii_from_image(image, size=128)) + image = model.generate_image(text, seed) + save_image(image, image_path) + print(ascii_from_image(image, size=128)) if __name__ == '__main__': args = parser.parse_args() print(args) generate_image( - is_torch=args.torch, is_mega=args.mega, 
text=args.text, seed=args.seed, diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py deleted file mode 100644 index 1ce2998..0000000 --- a/min_dalle/load_params.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import numpy -from typing import Dict -from flax.traverse_util import flatten_dict -from flax.serialization import msgpack_restore -import torch -torch.set_grad_enabled(False) - - -def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]: - with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f: - params: Dict[str, numpy.ndarray] = msgpack_restore(f.read()) - - P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.') - - for i in list(P.keys()): - j = i - if 'up' in i or 'down' in i: - j = i.replace('_', '.') - j = j.replace('proj.out', 'proj_out') - j = j.replace('nin.short', 'nin_short') - if 'bias' in i: - P[j] = P.pop(i) - elif 'scale' in i: - j = j.replace('scale', 'weight') - P[j] = P.pop(i) - elif 'kernel' in i: - j = j.replace('kernel', 'weight') - P[j] = P.pop(i).transpose(3, 2, 0, 1) - - for i in P: - P[i] = torch.tensor(P[i]) - - P['embedding.weight'] = P.pop('quantize.embedding.embedding') - - for i in list(P): - if i.split('.')[0] in ['encoder', 'quant_conv']: - P.pop(i) - - return P - - -def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]: - with open(os.path.join(path, "flax_model.msgpack"), "rb") as f: - params = msgpack_restore(f.read()) - - for codec in ['encoder', 'decoder']: - k = 'FlaxBart{}Layers'.format(codec.title()) - P: dict = params['model'][codec]['layers'][k] - P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0') - P['self_attn_layer_norm'] = P.pop('LayerNorm_1') - P['self_attn'] = P.pop('FlaxBartAttention_0') - if codec == 'decoder': - P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2') - P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3') - P['encoder_attn'] = P.pop('FlaxBartAttention_1') - P['glu']: dict = P.pop('GLU_0') - P['glu']['ln0'] = P['glu'].pop('LayerNorm_0') - 
P['glu']['ln1'] = P['glu'].pop('LayerNorm_1') - P['glu']['fc0'] = P['glu'].pop('Dense_0') - P['glu']['fc1'] = P['glu'].pop('Dense_1') - P['glu']['fc2'] = P['glu'].pop('Dense_2') - - for codec in ['encoder', 'decoder']: - layers_params = params['model'][codec].pop('layers') - params['model'][codec] = { - **params['model'][codec], - **layers_params - } - - model_params = params.pop('model') - params = {**params, **model_params} - - params['decoder']['lm_head'] = params.pop('lm_head') - - return params - - -def convert_dalle_bart_torch_from_flax_params( - params: dict, - layer_count: int, - is_encoder: bool -) -> dict: - P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.') - - for i in P: - P[i] = torch.tensor(P[i]).to(torch.float16) - - for i in list(P): - if 'kernel' in i: - j = i.replace('kernel', 'weight') - P[j] = P.pop(i).transpose(-1, -2) - elif 'scale' in i: - j = i.replace('scale', 'weight') - P[j] = P.pop(i) - - for i in list(P): - j = 'FlaxBart{}Layers'.format('Encoder' if is_encoder else 'Decoder') - if j in i: - for l in range(layer_count): - k = i.replace(j, 'layers.' 
+ str(l)) - P[k] = P[i][l] - P.pop(i) - - P['embed_tokens.weight'] = P.pop('embed_tokens.embedding') - P['embed_positions.weight'] = P.pop('embed_positions.embedding') - return P - - -def convert_and_save_torch_params(is_mega: bool, model_path: str): - print("converting params to torch") - layer_count = 24 if is_mega else 12 - flax_params = load_dalle_bart_flax_params(model_path) - encoder_params = convert_dalle_bart_torch_from_flax_params( - flax_params['encoder'], - layer_count=layer_count, - is_encoder=True - ) - decoder_params = convert_dalle_bart_torch_from_flax_params( - flax_params['decoder'], - layer_count=layer_count, - is_encoder=False - ) - - for i in decoder_params: - decoder_params[i] = decoder_params[i].to(torch.float16) - - for i in encoder_params: - encoder_params[i] = encoder_params[i].to(torch.float16) - - detoker_params = load_vqgan_torch_params('./pretrained/vqgan') - detoker_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') - - torch.save(encoder_params, os.path.join(model_path, 'encoder.pt')) - torch.save(decoder_params, os.path.join(model_path, 'decoder.pt')) - torch.save(detoker_params, detoker_path) \ No newline at end of file diff --git a/min_dalle/min_dalle_base.py b/min_dalle/min_dalle_base.py deleted file mode 100644 index 2d0b217..0000000 --- a/min_dalle/min_dalle_base.py +++ /dev/null @@ -1,32 +0,0 @@ -import os -import json -import numpy - -from .text_tokenizer import TextTokenizer - -class MinDalleBase: - def __init__(self, is_mega: bool): - self.is_mega = is_mega - model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') - self.model_path = os.path.join('pretrained', model_name) - - print("reading files from {}".format(self.model_path)) - vocab_path = os.path.join(self.model_path, 'vocab.json') - merges_path = os.path.join(self.model_path, 'merges.txt') - - with open(vocab_path, 'r', encoding='utf8') as f: - vocab = json.load(f) - with open(merges_path, 'r', encoding='utf8') as f: - merges = 
f.read().split("\n")[1:-1] - - self.tokenizer = TextTokenizer(vocab, merges) - - - def tokenize_text(self, text: str) -> numpy.ndarray: - print("tokenizing text") - tokens = self.tokenizer.tokenize(text) - print("text tokens", tokens) - text_tokens = numpy.ones((2, 64), dtype=numpy.int32) - text_tokens[0, :2] = [tokens[0], tokens[-1]] - text_tokens[1, :len(tokens)] = tokens - return text_tokens \ No newline at end of file diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py deleted file mode 100644 index 9ee53f8..0000000 --- a/min_dalle/min_dalle_flax.py +++ /dev/null @@ -1,87 +0,0 @@ -import jax -import numpy -from PIL import Image -import torch - -from .min_dalle_base import MinDalleBase -from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax -from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax -from .models.vqgan_detokenizer import VQGanDetokenizer - -from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params - - -class MinDalleFlax(MinDalleBase): - def __init__(self, is_mega: bool, is_reusable: bool = True): - super().__init__(is_mega) - self.is_reusable = is_reusable - print("initializing MinDalleFlax") - self.model_params = load_dalle_bart_flax_params(self.model_path) - if is_reusable: - self.init_encoder() - self.init_decoder() - self.init_detokenizer() - - - def init_encoder(self): - print("initializing DalleBartEncoderFlax") - self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( - attention_head_count = 32 if self.is_mega else 16, - embed_count = 2048 if self.is_mega else 1024, - glu_embed_count = 4096 if self.is_mega else 2730, - text_token_count = 64, - text_vocab_count = 50272 if self.is_mega else 50264, - layer_count = 24 if self.is_mega else 12 - ).bind({'params': self.model_params.pop('encoder')}) - - - def init_decoder(self): - print("initializing DalleBartDecoderFlax") - self.decoder = DalleBartDecoderFlax( - image_token_count = 256, - image_vocab_count = 16415 if self.is_mega else 
16384, - attention_head_count = 32 if self.is_mega else 16, - embed_count = 2048 if self.is_mega else 1024, - glu_embed_count = 4096 if self.is_mega else 2730, - layer_count = 24 if self.is_mega else 12, - start_token = 16415 if self.is_mega else 16384 - ) - - - def init_detokenizer(self): - print("initializing VQGanDetokenizer") - params = load_vqgan_torch_params('./pretrained/vqgan') - self.detokenizer = VQGanDetokenizer() - self.detokenizer.load_state_dict(params) - del params - - def generate_image(self, text: str, seed: int) -> Image.Image: - text_tokens = self.tokenize_text(text) - - if not self.is_reusable: self.init_encoder() - print("encoding text tokens") - encoder_state = self.encoder(text_tokens) - if not self.is_reusable: del self.encoder - - if not self.is_reusable: - self.init_decoder() - params = self.model_params.pop('decoder') - else: - params = self.model_params['decoder'] - print("sampling image tokens") - image_tokens = self.decoder.sample_image_tokens( - text_tokens, - encoder_state, - jax.random.PRNGKey(seed), - params - ) - if not self.is_reusable: del self.decoder - - image_tokens = torch.tensor(numpy.array(image_tokens)) - - if not self.is_reusable: self.init_detokenizer() - print("detokenizing image") - image = self.detokenizer.forward(image_tokens).to(torch.uint8) - if not self.is_reusable: del self.detokenizer - image = Image.fromarray(image.to('cpu').detach().numpy()) - return image \ No newline at end of file diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 936197e..f6c901c 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -1,18 +1,20 @@ import os from PIL import Image from typing import Dict +import numpy from torch import LongTensor import torch +import json torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) -from .min_dalle_base import MinDalleBase +from .text_tokenizer import TextTokenizer from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch 
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch from .models.vqgan_detokenizer import VQGanDetokenizer -class MinDalleTorch(MinDalleBase): +class MinDalleTorch: def __init__( self, is_mega: bool, @@ -20,7 +22,20 @@ class MinDalleTorch(MinDalleBase): token_count: int = 256 ): print("initializing MinDalleTorch") - super().__init__(is_mega) + self.is_mega = is_mega + model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') + self.model_path = os.path.join('pretrained', model_name) + + print("reading files from {}".format(self.model_path)) + vocab_path = os.path.join(self.model_path, 'vocab.json') + merges_path = os.path.join(self.model_path, 'merges.txt') + + with open(vocab_path, 'r', encoding='utf8') as f: + vocab = json.load(f) + with open(merges_path, 'r', encoding='utf8') as f: + merges = f.read().split("\n")[1:-1] + + self.tokenizer = TextTokenizer(vocab, merges) self.is_reusable = is_reusable self.token_count = token_count @@ -76,7 +91,17 @@ class MinDalleTorch(MinDalleBase): self.detokenizer.load_state_dict(params) del params if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() - + + + def tokenize_text(self, text: str) -> numpy.ndarray: + print("tokenizing text") + tokens = self.tokenizer.tokenize(text) + print("text tokens", tokens) + text_tokens = numpy.ones((2, 64), dtype=numpy.int32) + text_tokens[0, :2] = [tokens[0], tokens[-1]] + text_tokens[1, :len(tokens)] = tokens + return text_tokens + def generate_image_tokens(self, text: str, seed: int) -> LongTensor: text_tokens = self.tokenize_text(text) diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py deleted file mode 100644 index 4d68190..0000000 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ /dev/null @@ -1,247 +0,0 @@ -import jax, flax -from jax import lax, numpy as jnp -from flax import linen as nn -from typing import Tuple - -from .dalle_bart_encoder_flax import GLUFlax, AttentionFlax - - -class 
DecoderCrossAttentionFlax(AttentionFlax): - def __call__( - self, - decoder_state: jnp.ndarray, - encoder_state: jnp.ndarray, - attention_mask: jnp.ndarray, - ) -> jnp.ndarray: - keys = self.k_proj(encoder_state) - values = self.v_proj(encoder_state) - queries = self.q_proj(decoder_state) - return self.forward(keys, values, queries, attention_mask) - - -class DecoderSelfAttentionFlax(AttentionFlax): - def __call__( - self, - decoder_state: jnp.ndarray, - attention_state: jnp.ndarray, - attention_mask: jnp.ndarray, - state_index: tuple - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - keys = self.k_proj(decoder_state) - values = self.v_proj(decoder_state) - queries = self.q_proj(decoder_state) - - attention_state = lax.dynamic_update_slice( - attention_state, - jnp.concatenate([keys, values]).astype(jnp.float32), - state_index - ) - batch_count = decoder_state.shape[0] - keys = attention_state[:batch_count] - values = attention_state[batch_count:] - - decoder_state = self.forward( - keys, - values, - queries, - attention_mask - ).astype(decoder_state.dtype) - return decoder_state, attention_state - - -class DalleBartDecoderLayerFlax(nn.Module): - image_token_count: int - attention_head_count: int - embed_count: int - glu_embed_count: int - - def setup(self): - self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False) - self.self_attn = DecoderSelfAttentionFlax( - self.attention_head_count, - self.embed_count - ) - self.self_attn_layer_norm = nn.LayerNorm() - self.pre_encoder_attn_layer_norm = nn.LayerNorm(use_scale=False) - self.encoder_attn = DecoderCrossAttentionFlax( - self.attention_head_count, - self.embed_count, - ) - self.encoder_attn_layer_norm = nn.LayerNorm() - self.glu = GLUFlax(self.embed_count, self.glu_embed_count) - - @nn.compact - def __call__( - self, - decoder_state: jnp.ndarray, - encoder_state: jnp.ndarray, - attention_state: jnp.ndarray, - attention_mask: jnp.ndarray, - token_index: int - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - # Self Attention - 
residual = decoder_state - decoder_state = self.pre_self_attn_layer_norm(decoder_state) - self_attention_mask = jnp.tile( - jnp.arange(self.image_token_count) < token_index + 1, - (decoder_state.shape[0], 1) - ) - decoder_state, attention_state = self.self_attn( - decoder_state, - attention_state, - self_attention_mask, - (0, token_index, 0) - ) - decoder_state = self.self_attn_layer_norm(decoder_state) - decoder_state = residual + decoder_state - - # Cross Attention - residual = decoder_state - decoder_state = self.pre_encoder_attn_layer_norm(decoder_state) - decoder_state = self.encoder_attn( - decoder_state, - encoder_state, - attention_mask - ) - decoder_state = self.encoder_attn_layer_norm(decoder_state) - decoder_state = residual + decoder_state - - # Feed forward - residual = decoder_state - decoder_state = self.glu(decoder_state) - decoder_state = residual + decoder_state - - return decoder_state, attention_state - - -@flax.struct.dataclass -class SampleState: - prev_token: jnp.ndarray - prng_key: jnp.ndarray - attention_state: jnp.ndarray - -def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray: - return (1 - a) * logits[0, -1] + a * logits[1, -1] - -def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray: - top_logits, _ = lax.top_k(logits, k) - suppressed = -jnp.inf * jnp.ones_like(logits) - return lax.select(logits < top_logits[-1], suppressed, logits) - -class DalleBartDecoderFlax(nn.Module): - image_token_count: int - image_vocab_count: int - attention_head_count: int - embed_count: int - glu_embed_count: int - layer_count: int - start_token: int - - def setup(self): - self.embed_tokens = nn.Embed( - self.image_vocab_count + 1, - self.embed_count - ) - self.embed_positions = nn.Embed( - self.image_token_count, - self.embed_count - ) - self.layers = nn.scan( - DalleBartDecoderLayerFlax, - variable_axes = { "params": 0 }, - split_rngs = { "params": True }, - in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast), - out_axes = 0, - 
length=self.layer_count, - )( - self.image_token_count, - self.attention_head_count, - self.embed_count, - self.glu_embed_count, - name="FlaxBartDecoderLayers" - ) - self.layernorm_embedding = nn.LayerNorm() - self.final_ln = nn.LayerNorm(use_scale=False) - self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False) - - def __call__( - self, - encoder_state: jnp.ndarray, - attention_state: jnp.ndarray, - attention_mask: jnp.ndarray, - prev_token: int, - token_index: int - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - batch_count = encoder_state.shape[0] - ones = jnp.ones((batch_count, 1), dtype=jnp.int32) - decoder_state = self.embed_tokens(prev_token * ones) - decoder_state += self.embed_positions(token_index * ones) - decoder_state = self.layernorm_embedding(decoder_state) - decoder_state, attention_state = self.layers( - decoder_state, - encoder_state, - attention_state, - attention_mask, - token_index - ) - decoder_state = self.final_ln(decoder_state) - decoder_state = self.lm_head(decoder_state) - return decoder_state, attention_state - - def sample_image_tokens( - self, - text_tokens: jnp.ndarray, - encoder_state: jnp.ndarray, - prng_key: jax.random.PRNGKey, - params: dict - ) -> jnp.ndarray: - attention_mask = jnp.not_equal(text_tokens, 1) - - def sample_next_image_token( - state: SampleState, - token_index: int - ) -> Tuple[SampleState, jnp.ndarray]: - logits, attention_state = self.apply( - { 'params': params }, - encoder_state = encoder_state, - attention_state = state.attention_state, - attention_mask = attention_mask, - prev_token = state.prev_token, - token_index = token_index - ) - - logits = super_conditioned(logits, 10.0) - logits = keep_top_k(logits, k=50) - - prng_key, prng_key_next = jax.random.split(state.prng_key) - next_token = jax.random.categorical(prng_key, logits, axis=-1) - - state = SampleState( - prev_token = next_token, - prng_key = prng_key_next, - attention_state = attention_state - ) - - return state, next_token - - batch_count = 
encoder_state.shape[0] - attention_state_shape = ( - self.layer_count, - batch_count * 2, - self.image_token_count, - self.embed_count - ) - - initial_state = SampleState( - prev_token = self.start_token, - prng_key = prng_key, - attention_state = jnp.zeros(attention_state_shape) - ) - - _, image_tokens = lax.scan( - sample_next_image_token, - initial_state, - jnp.arange(self.image_token_count) - ) - - return image_tokens \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_encoder_flax.py b/min_dalle/models/dalle_bart_encoder_flax.py deleted file mode 100644 index 7a1cc1b..0000000 --- a/min_dalle/models/dalle_bart_encoder_flax.py +++ /dev/null @@ -1,151 +0,0 @@ -from functools import partial -import jax -from jax import lax, numpy as jnp -from flax import linen as nn - - -class GLUFlax(nn.Module): - count_in_out: int - count_middle: int - - def setup(self): - self.gelu = partial(nn.gelu, approximate=False) - self.ln0 = nn.LayerNorm(use_scale=False) - self.ln1 = nn.LayerNorm(use_scale=False) - self.fc0 = nn.Dense(self.count_middle, use_bias=False) - self.fc1 = nn.Dense(self.count_middle, use_bias=False) - self.fc2 = nn.Dense(self.count_in_out, use_bias=False) - - @nn.compact - def __call__(self, z: jnp.ndarray) -> jnp.ndarray: - z = self.ln0(z) - z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z)) - z = self.fc2(z) - return z - - -class AttentionFlax(nn.Module): - head_count: int - embed_count: int - - def setup(self): - self.q_proj = nn.Dense(self.embed_count, use_bias=False) - self.k_proj = nn.Dense(self.embed_count, use_bias=False) - self.v_proj = nn.Dense(self.embed_count, use_bias=False) - self.out_proj = nn.Dense(self.embed_count, use_bias=False) - - def forward( - self, - keys: jnp.ndarray, - values: jnp.ndarray, - queries: jnp.ndarray, - attention_mask: jnp.ndarray - ) -> jnp.ndarray: - keys = keys.reshape(keys.shape[:2] + (self.head_count, -1)) - values = values.reshape(values.shape[:2] + (self.head_count, -1)) - queries = 
queries.reshape(queries.shape[:2] + (self.head_count, -1)) - queries /= queries.shape[-1] ** 0.5 - attention_bias: jnp.ndarray = lax.select( - attention_mask, - jnp.full(attention_mask.shape, 0.0), - jnp.full(attention_mask.shape, -jnp.inf), - ) - attention_weights: jnp.ndarray = jnp.einsum( - 'bqhd,bkhd->bhqk', - queries, - keys - ) - attention_weights += attention_bias[:, None, None, :] - attention_weights = jax.nn.softmax(attention_weights) - attention_output: jnp.ndarray = jnp.einsum( - "bhqk,bkhd->bqhd", - attention_weights, - values - ) - shape = attention_output.shape[:2] + (self.embed_count,) - attention_output = attention_output.reshape(shape) - attention_output = self.out_proj(attention_output) - return attention_output - - -class EncoderSelfAttentionFlax(AttentionFlax): - def __call__( - self, - encoder_state: jnp.ndarray, - attention_mask: jnp.ndarray - ) -> jnp.ndarray: - keys = self.k_proj(encoder_state) - values = self.v_proj(encoder_state) - queries = self.q_proj(encoder_state) - return self.forward(keys, values, queries, attention_mask) - - -class DalleBartEncoderLayerFlax(nn.Module): - attention_head_count: int - embed_count: int - glu_embed_count: int - - def setup(self): - self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False) - self.self_attn = EncoderSelfAttentionFlax( - self.attention_head_count, - self.embed_count - ) - self.self_attn_layer_norm = nn.LayerNorm() - self.glu = GLUFlax(self.embed_count, self.glu_embed_count) - - @nn.compact - def __call__( - self, - encoder_state: jnp.ndarray, - attention_mask: jnp.ndarray - ) -> jnp.ndarray: - residual = encoder_state - encoder_state = self.pre_self_attn_layer_norm(encoder_state) - encoder_state = self.self_attn(encoder_state, attention_mask) - encoder_state = self.self_attn_layer_norm(encoder_state) - encoder_state = residual + encoder_state - residual = encoder_state - encoder_state = self.glu(encoder_state) - encoder_state = residual + encoder_state - return encoder_state, None - - 
-class DalleBartEncoderFlax(nn.Module): - attention_head_count: int - embed_count: int - glu_embed_count: int - text_token_count: int - text_vocab_count: int - layer_count: int - - def setup(self): - self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count) - self.embed_positions = nn.Embed(self.text_token_count, self.embed_count) - self.layers = nn.scan( - DalleBartEncoderLayerFlax, - variable_axes = { "params": 0 }, - split_rngs = { "params": True }, - in_axes = nn.broadcast, - length = self.layer_count - )( - self.attention_head_count, - self.embed_count, - self.glu_embed_count, - name="FlaxBartEncoderLayers" - ) - self.layernorm_embedding = nn.LayerNorm() - self.final_ln = nn.LayerNorm(use_scale=False) - - def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray: - batch_count, token_count = text_tokens.shape - pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1)) - attention_mask = jnp.not_equal(text_tokens, 1) - encoder_state = ( - self.embed_tokens(text_tokens) + - self.embed_positions(pose_tokens) - ) - encoder_state = self.layernorm_embedding(encoder_state) - encoder_state, _ = self.layers(encoder_state, attention_mask) - encoder_state = self.final_ln(encoder_state) - return encoder_state \ No newline at end of file diff --git a/cog.yaml b/replicate/cog.yaml similarity index 100% rename from cog.yaml rename to replicate/cog.yaml diff --git a/predict.py b/replicate/predict.py similarity index 100% rename from predict.py rename to replicate/predict.py diff --git a/requirements_flax.txt b/requirements_flax.txt deleted file mode 100644 index ee8ff54..0000000 --- a/requirements_flax.txt +++ /dev/null @@ -1,3 +0,0 @@ -flax -torch -wandb \ No newline at end of file diff --git a/setup.sh b/setup.sh index 9596260..d221a45 100644 --- a/setup.sh +++ b/setup.sh @@ -4,17 +4,22 @@ set -e pip3 install -r requirements.txt -mkdir -p ./pretrained/dalle_bart_mega/ -curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output 
./pretrained/dalle_bart_mega/vocab.json -curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt -curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt -curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt +repo_path="https://huggingface.co/kuprel/min-dalle/resolve/main" -mkdir -p ./pretrained/dalle_bart_mini/ -curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json -curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt -curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt -curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt +mini_path="./pretrained/dalle_bart_mini" +mega_path="./pretrained/dalle_bart_mega" +vqgan_path="./pretrained/vqgan" -mkdir -p ./pretrained/vqgan/ -curl https://huggingface.co/kuprel/min-dalle/resolve/main/detoker.pt -L --output ./pretrained/vqgan/detoker.pt \ No newline at end of file +mkdir -p ${vqgan_path} +mkdir -p ${mini_path} +mkdir -p ${mega_path} + +curl ${repo_path}/detoker.pt -L --output ${vqgan_path}/detoker.pt +curl ${repo_path}/vocab_mini.json -L --output ${mini_path}/vocab.json +curl ${repo_path}/merges_mini.txt -L --output ${mini_path}/merges.txt +curl ${repo_path}/encoder_mini.pt -L --output ${mini_path}/encoder.pt +curl ${repo_path}/decoder_mini.pt -L --output ${mini_path}/decoder.pt +curl ${repo_path}/vocab.json -L --output ${mega_path}/vocab.json +curl ${repo_path}/merges.txt -L --output ${mega_path}/merges.txt +curl ${repo_path}/encoder.pt -L --output ${mega_path}/encoder.pt +curl ${repo_path}/decoder.pt -L --output ${mega_path}/decoder.pt \ No 
newline at end of file diff --git a/setup_flax.sh b/setup_flax.sh deleted file mode 100644 index ff5fa84..0000000 --- a/setup_flax.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -set -e - -pip3 install -r requirements_flax.txt - -# download vqgan -mkdir -p pretrained/vqgan -curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack - -# download dalle-mini and dalle-mega -python3 -m wandb login --anonymously -python3 -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0 -python3 -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14