moved flax model and conversion code to separate repository
This commit is contained in:
parent
febd18df77
commit
07ce93d5f8
10
README.md
vendored
10
README.md
vendored
|
@ -3,21 +3,19 @@
|
||||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||||
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
||||||
|
|
||||||
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch.
|
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in PyTorch. It has been stripped to the bare essentials necessary for doing inference. The only third party dependencies are numpy and torch.
|
||||||
|
|
||||||
It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
|
It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
|
||||||
|
|
||||||
|
The flax model, and the code for coverting it to torch, has been moved [here](https://github.com/kuprel/min-dalle-flax).
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle/tree/main).
|
Run `sh setup.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle/tree/main).
|
||||||
The flax models can be manually downloaded here:
|
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
|
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize `MinDalleTorch`, then call `generate_image` with some text and a seed. See the colab for an example.
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
|
|
|
@ -3,15 +3,11 @@ import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from min_dalle.min_dalle_torch import MinDalleTorch
|
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||||
from min_dalle.min_dalle_flax import MinDalleFlax
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--mega', action='store_true')
|
parser.add_argument('--mega', action='store_true')
|
||||||
parser.add_argument('--no-mega', dest='mega', action='store_false')
|
parser.add_argument('--no-mega', dest='mega', action='store_false')
|
||||||
parser.set_defaults(mega=False)
|
parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--torch', action='store_true')
|
|
||||||
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
|
||||||
parser.set_defaults(torch=True)
|
|
||||||
parser.add_argument('--text', type=str, default='cat')
|
parser.add_argument('--text', type=str, default='cat')
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
|
@ -37,7 +33,6 @@ def save_image(image: Image.Image, path: str):
|
||||||
|
|
||||||
|
|
||||||
def generate_image(
|
def generate_image(
|
||||||
is_torch: bool,
|
|
||||||
is_mega: bool,
|
is_mega: bool,
|
||||||
text: str,
|
text: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
@ -45,29 +40,21 @@ def generate_image(
|
||||||
token_count: int
|
token_count: int
|
||||||
):
|
):
|
||||||
is_reusable = False
|
is_reusable = False
|
||||||
if is_torch:
|
model = MinDalleTorch(is_mega, is_reusable, token_count)
|
||||||
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
|
|
||||||
|
|
||||||
if token_count < 256:
|
|
||||||
image_tokens = image_generator.generate_image_tokens(text, seed)
|
|
||||||
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
image = image_generator.generate_image(text, seed)
|
|
||||||
|
|
||||||
|
if token_count < 256:
|
||||||
|
image_tokens = model.generate_image_tokens(text, seed)
|
||||||
|
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||||
else:
|
else:
|
||||||
image_generator = MinDalleFlax(is_mega, is_reusable)
|
image = model.generate_image(text, seed)
|
||||||
image = image_generator.generate_image(text, seed)
|
save_image(image, image_path)
|
||||||
|
print(ascii_from_image(image, size=128))
|
||||||
save_image(image, image_path)
|
|
||||||
print(ascii_from_image(image, size=128))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
generate_image(
|
generate_image(
|
||||||
is_torch=args.torch,
|
|
||||||
is_mega=args.mega,
|
is_mega=args.mega,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
|
|
|
@ -1,136 +0,0 @@
|
||||||
import os
|
|
||||||
import numpy
|
|
||||||
from typing import Dict
|
|
||||||
from flax.traverse_util import flatten_dict
|
|
||||||
from flax.serialization import msgpack_restore
|
|
||||||
import torch
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
|
|
||||||
|
|
||||||
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
|
||||||
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
|
|
||||||
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
|
|
||||||
|
|
||||||
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
|
||||||
|
|
||||||
for i in list(P.keys()):
|
|
||||||
j = i
|
|
||||||
if 'up' in i or 'down' in i:
|
|
||||||
j = i.replace('_', '.')
|
|
||||||
j = j.replace('proj.out', 'proj_out')
|
|
||||||
j = j.replace('nin.short', 'nin_short')
|
|
||||||
if 'bias' in i:
|
|
||||||
P[j] = P.pop(i)
|
|
||||||
elif 'scale' in i:
|
|
||||||
j = j.replace('scale', 'weight')
|
|
||||||
P[j] = P.pop(i)
|
|
||||||
elif 'kernel' in i:
|
|
||||||
j = j.replace('kernel', 'weight')
|
|
||||||
P[j] = P.pop(i).transpose(3, 2, 0, 1)
|
|
||||||
|
|
||||||
for i in P:
|
|
||||||
P[i] = torch.tensor(P[i])
|
|
||||||
|
|
||||||
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
|
|
||||||
|
|
||||||
for i in list(P):
|
|
||||||
if i.split('.')[0] in ['encoder', 'quant_conv']:
|
|
||||||
P.pop(i)
|
|
||||||
|
|
||||||
return P
|
|
||||||
|
|
||||||
|
|
||||||
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
|
|
||||||
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
|
|
||||||
params = msgpack_restore(f.read())
|
|
||||||
|
|
||||||
for codec in ['encoder', 'decoder']:
|
|
||||||
k = 'FlaxBart{}Layers'.format(codec.title())
|
|
||||||
P: dict = params['model'][codec]['layers'][k]
|
|
||||||
P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0')
|
|
||||||
P['self_attn_layer_norm'] = P.pop('LayerNorm_1')
|
|
||||||
P['self_attn'] = P.pop('FlaxBartAttention_0')
|
|
||||||
if codec == 'decoder':
|
|
||||||
P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2')
|
|
||||||
P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3')
|
|
||||||
P['encoder_attn'] = P.pop('FlaxBartAttention_1')
|
|
||||||
P['glu']: dict = P.pop('GLU_0')
|
|
||||||
P['glu']['ln0'] = P['glu'].pop('LayerNorm_0')
|
|
||||||
P['glu']['ln1'] = P['glu'].pop('LayerNorm_1')
|
|
||||||
P['glu']['fc0'] = P['glu'].pop('Dense_0')
|
|
||||||
P['glu']['fc1'] = P['glu'].pop('Dense_1')
|
|
||||||
P['glu']['fc2'] = P['glu'].pop('Dense_2')
|
|
||||||
|
|
||||||
for codec in ['encoder', 'decoder']:
|
|
||||||
layers_params = params['model'][codec].pop('layers')
|
|
||||||
params['model'][codec] = {
|
|
||||||
**params['model'][codec],
|
|
||||||
**layers_params
|
|
||||||
}
|
|
||||||
|
|
||||||
model_params = params.pop('model')
|
|
||||||
params = {**params, **model_params}
|
|
||||||
|
|
||||||
params['decoder']['lm_head'] = params.pop('lm_head')
|
|
||||||
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def convert_dalle_bart_torch_from_flax_params(
|
|
||||||
params: dict,
|
|
||||||
layer_count: int,
|
|
||||||
is_encoder: bool
|
|
||||||
) -> dict:
|
|
||||||
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
|
||||||
|
|
||||||
for i in P:
|
|
||||||
P[i] = torch.tensor(P[i]).to(torch.float16)
|
|
||||||
|
|
||||||
for i in list(P):
|
|
||||||
if 'kernel' in i:
|
|
||||||
j = i.replace('kernel', 'weight')
|
|
||||||
P[j] = P.pop(i).transpose(-1, -2)
|
|
||||||
elif 'scale' in i:
|
|
||||||
j = i.replace('scale', 'weight')
|
|
||||||
P[j] = P.pop(i)
|
|
||||||
|
|
||||||
for i in list(P):
|
|
||||||
j = 'FlaxBart{}Layers'.format('Encoder' if is_encoder else 'Decoder')
|
|
||||||
if j in i:
|
|
||||||
for l in range(layer_count):
|
|
||||||
k = i.replace(j, 'layers.' + str(l))
|
|
||||||
P[k] = P[i][l]
|
|
||||||
P.pop(i)
|
|
||||||
|
|
||||||
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
|
||||||
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
|
||||||
return P
|
|
||||||
|
|
||||||
|
|
||||||
def convert_and_save_torch_params(is_mega: bool, model_path: str):
|
|
||||||
print("converting params to torch")
|
|
||||||
layer_count = 24 if is_mega else 12
|
|
||||||
flax_params = load_dalle_bart_flax_params(model_path)
|
|
||||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
flax_params['encoder'],
|
|
||||||
layer_count=layer_count,
|
|
||||||
is_encoder=True
|
|
||||||
)
|
|
||||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
flax_params['decoder'],
|
|
||||||
layer_count=layer_count,
|
|
||||||
is_encoder=False
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in decoder_params:
|
|
||||||
decoder_params[i] = decoder_params[i].to(torch.float16)
|
|
||||||
|
|
||||||
for i in encoder_params:
|
|
||||||
encoder_params[i] = encoder_params[i].to(torch.float16)
|
|
||||||
|
|
||||||
detoker_params = load_vqgan_torch_params('./pretrained/vqgan')
|
|
||||||
detoker_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
|
|
||||||
|
|
||||||
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
|
|
||||||
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))
|
|
||||||
torch.save(detoker_params, detoker_path)
|
|
|
@ -1,32 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy
|
|
||||||
|
|
||||||
from .text_tokenizer import TextTokenizer
|
|
||||||
|
|
||||||
class MinDalleBase:
|
|
||||||
def __init__(self, is_mega: bool):
|
|
||||||
self.is_mega = is_mega
|
|
||||||
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
|
||||||
self.model_path = os.path.join('pretrained', model_name)
|
|
||||||
|
|
||||||
print("reading files from {}".format(self.model_path))
|
|
||||||
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
|
||||||
merges_path = os.path.join(self.model_path, 'merges.txt')
|
|
||||||
|
|
||||||
with open(vocab_path, 'r', encoding='utf8') as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
with open(merges_path, 'r', encoding='utf8') as f:
|
|
||||||
merges = f.read().split("\n")[1:-1]
|
|
||||||
|
|
||||||
self.tokenizer = TextTokenizer(vocab, merges)
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_text(self, text: str) -> numpy.ndarray:
|
|
||||||
print("tokenizing text")
|
|
||||||
tokens = self.tokenizer.tokenize(text)
|
|
||||||
print("text tokens", tokens)
|
|
||||||
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
|
|
||||||
text_tokens[0, :2] = [tokens[0], tokens[-1]]
|
|
||||||
text_tokens[1, :len(tokens)] = tokens
|
|
||||||
return text_tokens
|
|
|
@ -1,87 +0,0 @@
|
||||||
import jax
|
|
||||||
import numpy
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .min_dalle_base import MinDalleBase
|
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
|
||||||
|
|
||||||
from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params
|
|
||||||
|
|
||||||
|
|
||||||
class MinDalleFlax(MinDalleBase):
|
|
||||||
def __init__(self, is_mega: bool, is_reusable: bool = True):
|
|
||||||
super().__init__(is_mega)
|
|
||||||
self.is_reusable = is_reusable
|
|
||||||
print("initializing MinDalleFlax")
|
|
||||||
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
|
||||||
if is_reusable:
|
|
||||||
self.init_encoder()
|
|
||||||
self.init_decoder()
|
|
||||||
self.init_detokenizer()
|
|
||||||
|
|
||||||
|
|
||||||
def init_encoder(self):
|
|
||||||
print("initializing DalleBartEncoderFlax")
|
|
||||||
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
|
||||||
attention_head_count = 32 if self.is_mega else 16,
|
|
||||||
embed_count = 2048 if self.is_mega else 1024,
|
|
||||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
|
||||||
text_token_count = 64,
|
|
||||||
text_vocab_count = 50272 if self.is_mega else 50264,
|
|
||||||
layer_count = 24 if self.is_mega else 12
|
|
||||||
).bind({'params': self.model_params.pop('encoder')})
|
|
||||||
|
|
||||||
|
|
||||||
def init_decoder(self):
|
|
||||||
print("initializing DalleBartDecoderFlax")
|
|
||||||
self.decoder = DalleBartDecoderFlax(
|
|
||||||
image_token_count = 256,
|
|
||||||
image_vocab_count = 16415 if self.is_mega else 16384,
|
|
||||||
attention_head_count = 32 if self.is_mega else 16,
|
|
||||||
embed_count = 2048 if self.is_mega else 1024,
|
|
||||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
|
||||||
layer_count = 24 if self.is_mega else 12,
|
|
||||||
start_token = 16415 if self.is_mega else 16384
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def init_detokenizer(self):
|
|
||||||
print("initializing VQGanDetokenizer")
|
|
||||||
params = load_vqgan_torch_params('./pretrained/vqgan')
|
|
||||||
self.detokenizer = VQGanDetokenizer()
|
|
||||||
self.detokenizer.load_state_dict(params)
|
|
||||||
del params
|
|
||||||
|
|
||||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
|
||||||
text_tokens = self.tokenize_text(text)
|
|
||||||
|
|
||||||
if not self.is_reusable: self.init_encoder()
|
|
||||||
print("encoding text tokens")
|
|
||||||
encoder_state = self.encoder(text_tokens)
|
|
||||||
if not self.is_reusable: del self.encoder
|
|
||||||
|
|
||||||
if not self.is_reusable:
|
|
||||||
self.init_decoder()
|
|
||||||
params = self.model_params.pop('decoder')
|
|
||||||
else:
|
|
||||||
params = self.model_params['decoder']
|
|
||||||
print("sampling image tokens")
|
|
||||||
image_tokens = self.decoder.sample_image_tokens(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
jax.random.PRNGKey(seed),
|
|
||||||
params
|
|
||||||
)
|
|
||||||
if not self.is_reusable: del self.decoder
|
|
||||||
|
|
||||||
image_tokens = torch.tensor(numpy.array(image_tokens))
|
|
||||||
|
|
||||||
if not self.is_reusable: self.init_detokenizer()
|
|
||||||
print("detokenizing image")
|
|
||||||
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
|
||||||
if not self.is_reusable: del self.detokenizer
|
|
||||||
image = Image.fromarray(image.to('cpu').detach().numpy())
|
|
||||||
return image
|
|
|
@ -1,18 +1,20 @@
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
import numpy
|
||||||
from torch import LongTensor
|
from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
|
import json
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .min_dalle_base import MinDalleBase
|
from .text_tokenizer import TextTokenizer
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
|
|
||||||
class MinDalleTorch(MinDalleBase):
|
class MinDalleTorch:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
is_mega: bool,
|
is_mega: bool,
|
||||||
|
@ -20,7 +22,20 @@ class MinDalleTorch(MinDalleBase):
|
||||||
token_count: int = 256
|
token_count: int = 256
|
||||||
):
|
):
|
||||||
print("initializing MinDalleTorch")
|
print("initializing MinDalleTorch")
|
||||||
super().__init__(is_mega)
|
self.is_mega = is_mega
|
||||||
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||||
|
self.model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
|
print("reading files from {}".format(self.model_path))
|
||||||
|
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
||||||
|
merges_path = os.path.join(self.model_path, 'merges.txt')
|
||||||
|
|
||||||
|
with open(vocab_path, 'r', encoding='utf8') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
with open(merges_path, 'r', encoding='utf8') as f:
|
||||||
|
merges = f.read().split("\n")[1:-1]
|
||||||
|
|
||||||
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
self.is_reusable = is_reusable
|
self.is_reusable = is_reusable
|
||||||
self.token_count = token_count
|
self.token_count = token_count
|
||||||
|
|
||||||
|
@ -78,6 +93,16 @@ class MinDalleTorch(MinDalleBase):
|
||||||
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||||
|
print("tokenizing text")
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
print("text tokens", tokens)
|
||||||
|
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
|
||||||
|
text_tokens[0, :2] = [tokens[0], tokens[-1]]
|
||||||
|
text_tokens[1, :len(tokens)] = tokens
|
||||||
|
return text_tokens
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||||
text_tokens = self.tokenize_text(text)
|
text_tokens = self.tokenize_text(text)
|
||||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||||
|
|
|
@ -1,247 +0,0 @@
|
||||||
import jax, flax
|
|
||||||
from jax import lax, numpy as jnp
|
|
||||||
from flax import linen as nn
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from .dalle_bart_encoder_flax import GLUFlax, AttentionFlax
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderCrossAttentionFlax(AttentionFlax):
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
decoder_state: jnp.ndarray,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
keys = self.k_proj(encoder_state)
|
|
||||||
values = self.v_proj(encoder_state)
|
|
||||||
queries = self.q_proj(decoder_state)
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
decoder_state: jnp.ndarray,
|
|
||||||
attention_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
|
||||||
state_index: tuple
|
|
||||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
|
||||||
keys = self.k_proj(decoder_state)
|
|
||||||
values = self.v_proj(decoder_state)
|
|
||||||
queries = self.q_proj(decoder_state)
|
|
||||||
|
|
||||||
attention_state = lax.dynamic_update_slice(
|
|
||||||
attention_state,
|
|
||||||
jnp.concatenate([keys, values]).astype(jnp.float32),
|
|
||||||
state_index
|
|
||||||
)
|
|
||||||
batch_count = decoder_state.shape[0]
|
|
||||||
keys = attention_state[:batch_count]
|
|
||||||
values = attention_state[batch_count:]
|
|
||||||
|
|
||||||
decoder_state = self.forward(
|
|
||||||
keys,
|
|
||||||
values,
|
|
||||||
queries,
|
|
||||||
attention_mask
|
|
||||||
).astype(decoder_state.dtype)
|
|
||||||
return decoder_state, attention_state
|
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderLayerFlax(nn.Module):
|
|
||||||
image_token_count: int
|
|
||||||
attention_head_count: int
|
|
||||||
embed_count: int
|
|
||||||
glu_embed_count: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
|
|
||||||
self.self_attn = DecoderSelfAttentionFlax(
|
|
||||||
self.attention_head_count,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
self.self_attn_layer_norm = nn.LayerNorm()
|
|
||||||
self.pre_encoder_attn_layer_norm = nn.LayerNorm(use_scale=False)
|
|
||||||
self.encoder_attn = DecoderCrossAttentionFlax(
|
|
||||||
self.attention_head_count,
|
|
||||||
self.embed_count,
|
|
||||||
)
|
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm()
|
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
decoder_state: jnp.ndarray,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
attention_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
|
||||||
token_index: int
|
|
||||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
|
||||||
# Self Attention
|
|
||||||
residual = decoder_state
|
|
||||||
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
|
||||||
self_attention_mask = jnp.tile(
|
|
||||||
jnp.arange(self.image_token_count) < token_index + 1,
|
|
||||||
(decoder_state.shape[0], 1)
|
|
||||||
)
|
|
||||||
decoder_state, attention_state = self.self_attn(
|
|
||||||
decoder_state,
|
|
||||||
attention_state,
|
|
||||||
self_attention_mask,
|
|
||||||
(0, token_index, 0)
|
|
||||||
)
|
|
||||||
decoder_state = self.self_attn_layer_norm(decoder_state)
|
|
||||||
decoder_state = residual + decoder_state
|
|
||||||
|
|
||||||
# Cross Attention
|
|
||||||
residual = decoder_state
|
|
||||||
decoder_state = self.pre_encoder_attn_layer_norm(decoder_state)
|
|
||||||
decoder_state = self.encoder_attn(
|
|
||||||
decoder_state,
|
|
||||||
encoder_state,
|
|
||||||
attention_mask
|
|
||||||
)
|
|
||||||
decoder_state = self.encoder_attn_layer_norm(decoder_state)
|
|
||||||
decoder_state = residual + decoder_state
|
|
||||||
|
|
||||||
# Feed forward
|
|
||||||
residual = decoder_state
|
|
||||||
decoder_state = self.glu(decoder_state)
|
|
||||||
decoder_state = residual + decoder_state
|
|
||||||
|
|
||||||
return decoder_state, attention_state
|
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
|
||||||
class SampleState:
|
|
||||||
prev_token: jnp.ndarray
|
|
||||||
prng_key: jnp.ndarray
|
|
||||||
attention_state: jnp.ndarray
|
|
||||||
|
|
||||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
|
||||||
return (1 - a) * logits[0, -1] + a * logits[1, -1]
|
|
||||||
|
|
||||||
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
|
||||||
top_logits, _ = lax.top_k(logits, k)
|
|
||||||
suppressed = -jnp.inf * jnp.ones_like(logits)
|
|
||||||
return lax.select(logits < top_logits[-1], suppressed, logits)
|
|
||||||
|
|
||||||
class DalleBartDecoderFlax(nn.Module):
|
|
||||||
image_token_count: int
|
|
||||||
image_vocab_count: int
|
|
||||||
attention_head_count: int
|
|
||||||
embed_count: int
|
|
||||||
glu_embed_count: int
|
|
||||||
layer_count: int
|
|
||||||
start_token: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.embed_tokens = nn.Embed(
|
|
||||||
self.image_vocab_count + 1,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
self.embed_positions = nn.Embed(
|
|
||||||
self.image_token_count,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
self.layers = nn.scan(
|
|
||||||
DalleBartDecoderLayerFlax,
|
|
||||||
variable_axes = { "params": 0 },
|
|
||||||
split_rngs = { "params": True },
|
|
||||||
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
|
||||||
out_axes = 0,
|
|
||||||
length=self.layer_count,
|
|
||||||
)(
|
|
||||||
self.image_token_count,
|
|
||||||
self.attention_head_count,
|
|
||||||
self.embed_count,
|
|
||||||
self.glu_embed_count,
|
|
||||||
name="FlaxBartDecoderLayers"
|
|
||||||
)
|
|
||||||
self.layernorm_embedding = nn.LayerNorm()
|
|
||||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
|
||||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
attention_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
|
||||||
prev_token: int,
|
|
||||||
token_index: int
|
|
||||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
|
||||||
batch_count = encoder_state.shape[0]
|
|
||||||
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
|
||||||
decoder_state = self.embed_tokens(prev_token * ones)
|
|
||||||
decoder_state += self.embed_positions(token_index * ones)
|
|
||||||
decoder_state = self.layernorm_embedding(decoder_state)
|
|
||||||
decoder_state, attention_state = self.layers(
|
|
||||||
decoder_state,
|
|
||||||
encoder_state,
|
|
||||||
attention_state,
|
|
||||||
attention_mask,
|
|
||||||
token_index
|
|
||||||
)
|
|
||||||
decoder_state = self.final_ln(decoder_state)
|
|
||||||
decoder_state = self.lm_head(decoder_state)
|
|
||||||
return decoder_state, attention_state
|
|
||||||
|
|
||||||
def sample_image_tokens(
|
|
||||||
self,
|
|
||||||
text_tokens: jnp.ndarray,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
prng_key: jax.random.PRNGKey,
|
|
||||||
params: dict
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
attention_mask = jnp.not_equal(text_tokens, 1)
|
|
||||||
|
|
||||||
def sample_next_image_token(
|
|
||||||
state: SampleState,
|
|
||||||
token_index: int
|
|
||||||
) -> Tuple[SampleState, jnp.ndarray]:
|
|
||||||
logits, attention_state = self.apply(
|
|
||||||
{ 'params': params },
|
|
||||||
encoder_state = encoder_state,
|
|
||||||
attention_state = state.attention_state,
|
|
||||||
attention_mask = attention_mask,
|
|
||||||
prev_token = state.prev_token,
|
|
||||||
token_index = token_index
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = super_conditioned(logits, 10.0)
|
|
||||||
logits = keep_top_k(logits, k=50)
|
|
||||||
|
|
||||||
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
|
||||||
next_token = jax.random.categorical(prng_key, logits, axis=-1)
|
|
||||||
|
|
||||||
state = SampleState(
|
|
||||||
prev_token = next_token,
|
|
||||||
prng_key = prng_key_next,
|
|
||||||
attention_state = attention_state
|
|
||||||
)
|
|
||||||
|
|
||||||
return state, next_token
|
|
||||||
|
|
||||||
batch_count = encoder_state.shape[0]
|
|
||||||
attention_state_shape = (
|
|
||||||
self.layer_count,
|
|
||||||
batch_count * 2,
|
|
||||||
self.image_token_count,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
|
|
||||||
initial_state = SampleState(
|
|
||||||
prev_token = self.start_token,
|
|
||||||
prng_key = prng_key,
|
|
||||||
attention_state = jnp.zeros(attention_state_shape)
|
|
||||||
)
|
|
||||||
|
|
||||||
_, image_tokens = lax.scan(
|
|
||||||
sample_next_image_token,
|
|
||||||
initial_state,
|
|
||||||
jnp.arange(self.image_token_count)
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_tokens
|
|
|
@ -1,151 +0,0 @@
|
||||||
from functools import partial
|
|
||||||
import jax
|
|
||||||
from jax import lax, numpy as jnp
|
|
||||||
from flax import linen as nn
|
|
||||||
|
|
||||||
|
|
||||||
class GLUFlax(nn.Module):
|
|
||||||
count_in_out: int
|
|
||||||
count_middle: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.gelu = partial(nn.gelu, approximate=False)
|
|
||||||
self.ln0 = nn.LayerNorm(use_scale=False)
|
|
||||||
self.ln1 = nn.LayerNorm(use_scale=False)
|
|
||||||
self.fc0 = nn.Dense(self.count_middle, use_bias=False)
|
|
||||||
self.fc1 = nn.Dense(self.count_middle, use_bias=False)
|
|
||||||
self.fc2 = nn.Dense(self.count_in_out, use_bias=False)
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
|
|
||||||
z = self.ln0(z)
|
|
||||||
z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z))
|
|
||||||
z = self.fc2(z)
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionFlax(nn.Module):
|
|
||||||
head_count: int
|
|
||||||
embed_count: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.q_proj = nn.Dense(self.embed_count, use_bias=False)
|
|
||||||
self.k_proj = nn.Dense(self.embed_count, use_bias=False)
|
|
||||||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
|
||||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
keys: jnp.ndarray,
|
|
||||||
values: jnp.ndarray,
|
|
||||||
queries: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
|
||||||
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
|
||||||
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
attention_bias: jnp.ndarray = lax.select(
|
|
||||||
attention_mask,
|
|
||||||
jnp.full(attention_mask.shape, 0.0),
|
|
||||||
jnp.full(attention_mask.shape, -jnp.inf),
|
|
||||||
)
|
|
||||||
attention_weights: jnp.ndarray = jnp.einsum(
|
|
||||||
'bqhd,bkhd->bhqk',
|
|
||||||
queries,
|
|
||||||
keys
|
|
||||||
)
|
|
||||||
attention_weights += attention_bias[:, None, None, :]
|
|
||||||
attention_weights = jax.nn.softmax(attention_weights)
|
|
||||||
attention_output: jnp.ndarray = jnp.einsum(
|
|
||||||
"bhqk,bkhd->bqhd",
|
|
||||||
attention_weights,
|
|
||||||
values
|
|
||||||
)
|
|
||||||
shape = attention_output.shape[:2] + (self.embed_count,)
|
|
||||||
attention_output = attention_output.reshape(shape)
|
|
||||||
attention_output = self.out_proj(attention_output)
|
|
||||||
return attention_output
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderSelfAttentionFlax(AttentionFlax):
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
keys = self.k_proj(encoder_state)
|
|
||||||
values = self.v_proj(encoder_state)
|
|
||||||
queries = self.q_proj(encoder_state)
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderLayerFlax(nn.Module):
|
|
||||||
attention_head_count: int
|
|
||||||
embed_count: int
|
|
||||||
glu_embed_count: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False)
|
|
||||||
self.self_attn = EncoderSelfAttentionFlax(
|
|
||||||
self.attention_head_count,
|
|
||||||
self.embed_count
|
|
||||||
)
|
|
||||||
self.self_attn_layer_norm = nn.LayerNorm()
|
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
residual = encoder_state
|
|
||||||
encoder_state = self.pre_self_attn_layer_norm(encoder_state)
|
|
||||||
encoder_state = self.self_attn(encoder_state, attention_mask)
|
|
||||||
encoder_state = self.self_attn_layer_norm(encoder_state)
|
|
||||||
encoder_state = residual + encoder_state
|
|
||||||
residual = encoder_state
|
|
||||||
encoder_state = self.glu(encoder_state)
|
|
||||||
encoder_state = residual + encoder_state
|
|
||||||
return encoder_state, None
|
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderFlax(nn.Module):
|
|
||||||
attention_head_count: int
|
|
||||||
embed_count: int
|
|
||||||
glu_embed_count: int
|
|
||||||
text_token_count: int
|
|
||||||
text_vocab_count: int
|
|
||||||
layer_count: int
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count)
|
|
||||||
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
|
|
||||||
self.layers = nn.scan(
|
|
||||||
DalleBartEncoderLayerFlax,
|
|
||||||
variable_axes = { "params": 0 },
|
|
||||||
split_rngs = { "params": True },
|
|
||||||
in_axes = nn.broadcast,
|
|
||||||
length = self.layer_count
|
|
||||||
)(
|
|
||||||
self.attention_head_count,
|
|
||||||
self.embed_count,
|
|
||||||
self.glu_embed_count,
|
|
||||||
name="FlaxBartEncoderLayers"
|
|
||||||
)
|
|
||||||
self.layernorm_embedding = nn.LayerNorm()
|
|
||||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
|
||||||
|
|
||||||
def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray:
|
|
||||||
batch_count, token_count = text_tokens.shape
|
|
||||||
pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1))
|
|
||||||
attention_mask = jnp.not_equal(text_tokens, 1)
|
|
||||||
encoder_state = (
|
|
||||||
self.embed_tokens(text_tokens) +
|
|
||||||
self.embed_positions(pose_tokens)
|
|
||||||
)
|
|
||||||
encoder_state = self.layernorm_embedding(encoder_state)
|
|
||||||
encoder_state, _ = self.layers(encoder_state, attention_mask)
|
|
||||||
encoder_state = self.final_ln(encoder_state)
|
|
||||||
return encoder_state
|
|
0
cog.yaml → replicate/cog.yaml
vendored
0
cog.yaml → replicate/cog.yaml
vendored
3
requirements_flax.txt
vendored
3
requirements_flax.txt
vendored
|
@ -1,3 +0,0 @@
|
||||||
flax
|
|
||||||
torch
|
|
||||||
wandb
|
|
29
setup.sh
vendored
29
setup.sh
vendored
|
@ -4,17 +4,22 @@ set -e
|
||||||
|
|
||||||
pip3 install -r requirements.txt
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
mkdir -p ./pretrained/dalle_bart_mega/
|
repo_path="https://huggingface.co/kuprel/min-dalle/resolve/main"
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
|
|
||||||
|
|
||||||
mkdir -p ./pretrained/dalle_bart_mini/
|
mini_path="./pretrained/dalle_bart_mini"
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json
|
mega_path="./pretrained/dalle_bart_mega"
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt
|
vqgan_path="./pretrained/vqgan"
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt
|
|
||||||
|
|
||||||
mkdir -p ./pretrained/vqgan/
|
mkdir -p ${vqgan_path}
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/detoker.pt -L --output ./pretrained/vqgan/detoker.pt
|
mkdir -p ${mini_path}
|
||||||
|
mkdir -p ${mega_path}
|
||||||
|
|
||||||
|
curl ${repo_path}/detoker.pt -L --output ${vqgan_path}/detoker.pt
|
||||||
|
curl ${repo_path}/vocab_mini.json -L --output ${mini_path}/vocab.json
|
||||||
|
curl ${repo_path}/merges_mini.txt -L --output ${mini_path}/merges.txt
|
||||||
|
curl ${repo_path}/encoder_mini.pt -L --output ${mini_path}/encoder.pt
|
||||||
|
curl ${repo_path}/decoder_mini.pt -L --output ${mini_path}/decoder.pt
|
||||||
|
curl ${repo_path}/vocab.json -L --output ${mega_path}/vocab.json
|
||||||
|
curl ${repo_path}/merges.txt -L --output ${mega_path}/merges.txt
|
||||||
|
curl ${repo_path}/encoder.pt -L --output ${mega_path}/encoder.pt
|
||||||
|
curl ${repo_path}/decoder.pt -L --output ${mega_path}/decoder.pt
|
14
setup_flax.sh
vendored
14
setup_flax.sh
vendored
|
@ -1,14 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
pip3 install -r requirements_flax.txt
|
|
||||||
|
|
||||||
# download vqgan
|
|
||||||
mkdir -p pretrained/vqgan
|
|
||||||
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
|
|
||||||
|
|
||||||
# download dalle-mini and dalle-mega
|
|
||||||
python3 -m wandb login --anonymously
|
|
||||||
python3 -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
|
||||||
python3 -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
|
Loading…
Reference in New Issue
Block a user