Merge branch 'main' into patch-1

This commit is contained in:
Brett Kuprel 2022-07-01 02:40:38 -04:00 committed by GitHub
commit e3329a7f64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 593 additions and 488 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
* linguist-vendored
*.py linguist-vendored=false

19
README.md vendored
View File

@ -1,28 +1,33 @@
# min(DALL·E)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)  
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which are saved with `torch.save` the first time the model is loaded), and wandb is only used to download the models.
It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. In the bash script, Git LFS is used to download the VQGan detokenizer from Hugging Face and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. These models can also be downloaded manually:
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
### Examples
```
python image_from_text.py --text='alien life' --seed=7
python image_from_text.py --text='artificial intelligence' --torch
```
![Alien](examples/alien.png)
![Alien](examples/artificial_intelligence.png)
```
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
```
![Avocado Armchair](examples/avocado_armchair.png)

12
cog.yaml vendored Normal file
View File

@ -0,0 +1,12 @@
build:
cuda: "11.0"
gpu: true
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "torch==1.10.1"
- "flax==0.5.2"
predict: "predict.py:Predictor"

BIN
examples/alien.png vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

BIN
examples/artificial_intelligence.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 101 KiB

View File

@ -2,8 +2,8 @@ import argparse
import os
from PIL import Image
from min_dalle.generate_image import generate_image_from_text
from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
parser.add_argument('--torch', action='store_true')
parser.add_argument('--no-torch', dest='torch', action='store_false')
parser.set_defaults(torch=False)
parser.add_argument('--text', type=str)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--text', type=str, default='alien life')
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
parser.add_argument('--token_count', type=int, default=256) # for debugging
def ascii_from_image(image: Image.Image, size: int) -> str:
@ -36,19 +36,41 @@ def save_image(image: Image.Image, path: str):
return image
def generate_image(
is_torch: bool,
is_mega: bool,
text: str,
seed: int,
image_path: str,
token_count: int
):
is_reusable = False
if is_torch:
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
if token_count < image_generator.config['image_length']:
image_tokens = image_generator.generate_image_tokens(text, seed)
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)
else:
image_generator = MinDalleFlax(is_mega, is_reusable)
image = image_generator.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__':
args = parser.parse_args()
print(args)
image = generate_image_from_text(
text = args.text,
is_mega = args.mega,
is_torch = args.torch,
seed = args.seed,
image_token_count = args.image_token_count
generate_image(
is_torch=args.torch,
is_mega=args.mega,
text=args.text,
seed=args.seed,
image_path=args.image_path,
token_count=args.token_count
)
if image != None:
save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

130
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

View File

@ -1,77 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
generate_image_tokens_torch,
detokenize_torch
)
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("parsing metadata from {}".format(path))
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def tokenize_text(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
def generate_image_from_text(
text: str,
is_mega: bool = False,
is_torch: bool = False,
seed: int = 0,
image_token_count: int = 256
) -> Image.Image:
model_name = 'mega' if is_mega else 'mini'
model_path = './pretrained/dalle_bart_{}'.format(model_name)
config, vocab, merges = load_dalle_bart_metadata(model_path)
text_tokens = tokenize_text(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
image_tokens = numpy.zeros(config['image_length'])
if is_torch:
image_tokens[:image_token_count] = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
image_token_count = image_token_count
)
else:
image_tokens[...] = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens)
return Image.fromarray(image)
else:
return None

View File

@ -1,17 +1,17 @@
import os
import numpy
from copy import deepcopy
from typing import Dict
from flax import traverse_util, serialization
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore
import torch
torch.no_grad()
torch.set_grad_enabled(False)
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in list(P.keys()):
j = i
@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
@ -43,7 +42,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
params = serialization.msgpack_restore(f.read())
params = msgpack_restore(f.read())
for codec in ['encoder', 'decoder']:
k = 'FlaxBart{}Layers'.format(codec.title())
@ -82,12 +81,10 @@ def convert_dalle_bart_torch_from_flax_params(
layer_count: int,
is_encoder: bool
) -> dict:
P = deepcopy(params)
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
P[i] = torch.tensor(P[i]).to(torch.float16)
for i in list(P):
if 'kernel' in i:
@ -108,3 +105,28 @@ def convert_dalle_bart_torch_from_flax_params(
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
return P
def convert_and_save_torch_params(is_mega: bool, model_path: str):
print("converting params to torch")
layer_count = 24 if is_mega else 12
flax_params = load_dalle_bart_flax_params(model_path)
encoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['encoder'],
layer_count=layer_count,
is_encoder=True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['decoder'],
layer_count=layer_count,
is_encoder=False
)
for i in decoder_params:
decoder_params[i] = decoder_params[i].to(torch.float16)
for i in encoder_params:
encoder_params[i] = encoder_params[i].to(torch.float16)
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))

View File

@ -0,0 +1,46 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalleBase:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(self.model_path))
config_path = os.path.join(self.model_path, 'config.json')
vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.tokenizer = TextTokenizer(vocab, merges)
def init_detokenizer(self):
print("initializing VQGanDetokenizer")
params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer = VQGanDetokenizer()
self.detokenizer.load_state_dict(params)
del params
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens

View File

@ -1,79 +1,80 @@
import jax
from jax import numpy as jnp
import numpy
from PIL import Image
import torch
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
def encode_flax(
text_tokens: numpy.ndarray,
config: dict,
params: dict
) -> jnp.ndarray:
print("loading flax encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['encoder_ffn_dim'],
text_token_count = config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'],
layer_count = config['encoder_layers']
).bind({'params': params.pop('encoder')})
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
from .load_params import load_dalle_bart_flax_params
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
class MinDalleFlax(MinDalleBase):
def __init__(self, is_mega: bool, is_reusable: bool = True):
super().__init__(is_mega)
self.is_reusable = is_reusable
print("initializing MinDalleFlax")
self.model_params = load_dalle_bart_flax_params(self.model_path)
if is_reusable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()
def init_encoder(self):
print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = self.config['max_text_length'],
text_vocab_count = self.config['encoder_vocab_size'],
layer_count = self.config['encoder_layers']
).bind({'params': self.model_params.pop('encoder')})
def init_decoder(self):
print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
)
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)
if not self.is_reusable: self.init_encoder()
print("encoding text tokens")
encoder_state = self.encoder(text_tokens)
if not self.is_reusable: del self.encoder
if not self.is_reusable:
self.init_decoder()
params = self.model_params.pop('decoder')
else:
params = self.model_params['decoder']
print("sampling image tokens")
image_tokens = decoder.sample_image_tokens(
image_tokens = self.decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
params.pop('decoder')
)
del decoder
return image_tokens
def generate_image_tokens_flax(
text_tokens: numpy.ndarray,
seed: int,
config: dict,
params: dict
) -> numpy.ndarray:
encoder_state = encode_flax(
text_tokens,
config,
params
)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config,
seed,
params
)
image_tokens = numpy.array(image_tokens)
print("image tokens", list(image_tokens))
return image_tokens
if not self.is_reusable: del self.decoder
image_tokens = torch.tensor(numpy.array(image_tokens))
if not self.is_reusable: self.init_detokenizer()
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
if not self.is_reusable: del self.detokenizer
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

View File

@ -1,113 +1,114 @@
import numpy
import os
from PIL import Image
from typing import Dict
from torch import LongTensor, FloatTensor
from torch import LongTensor
import torch
torch.no_grad()
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .models.vqgan_detokenizer import VQGanDetokenizer
from .load_params import (
convert_and_save_torch_params,
load_dalle_bart_flax_params
)
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .load_params import (
load_vqgan_torch_params,
convert_dalle_bart_torch_from_flax_params
)
class MinDalleTorch(MinDalleBase):
def __init__(
self,
is_mega: bool,
is_reusable: bool = True,
token_count: int = 256
):
print("initializing MinDalleTorch")
super().__init__(is_mega)
self.is_reusable = is_reusable
self.token_count = token_count
if not is_mega:
self.model_params = load_dalle_bart_flax_params(self.model_path)
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
is_converted = os.path.exists(self.encoder_params_path)
is_converted &= os.path.exists(self.decoder_params_path)
if not is_converted:
convert_and_save_torch_params(is_mega, self.model_path)
if is_reusable:
self.init_encoder()
self.init_decoder()
self.init_detokenizer()
def encode_torch(
text_tokens: LongTensor,
config: dict,
params: dict
) -> FloatTensor:
print("loading torch encoder")
encoder = DalleBartEncoderTorch(
layer_count = config['encoder_layers'],
embed_count = config['d_model'],
attention_head_count = config['encoder_attention_heads'],
text_vocab_count = config['encoder_vocab_size'],
text_token_count = config['max_text_length'],
glu_embed_count = config['encoder_ffn_dim']
def init_encoder(self):
print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'],
embed_count = self.config['d_model'],
attention_head_count = self.config['encoder_attention_heads'],
text_vocab_count = self.config['encoder_vocab_size'],
text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim']
)
encoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('encoder'),
layer_count=config['encoder_layers'],
is_encoder=True
)
encoder.load_state_dict(encoder_params, strict=False)
del encoder_params
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False)
del params
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
def decode_torch(
text_tokens: LongTensor,
encoder_state: FloatTensor,
config: dict,
seed: int,
params: dict,
image_token_count: int
) -> LongTensor:
print("loading torch decoder")
decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'],
image_token_count = config['image_length'],
sample_token_count = image_token_count,
embed_count = config['d_model'],
attention_head_count = config['decoder_attention_heads'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
def init_decoder(self):
print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = self.token_count,
embed_count = self.config['d_model'],
attention_head_count = self.config['decoder_attention_heads'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
batch_count = 2,
start_token = config['decoder_start_token_id'],
start_token = self.config['decoder_start_token_id'],
is_verbose = True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('decoder'),
layer_count=config['decoder_layers'],
is_encoder=False
)
decoder.load_state_dict(decoder_params, strict=False)
del decoder_params
print("sampling image tokens")
torch.manual_seed(seed)
image_tokens = decoder.forward(text_tokens, encoder_state)
return image_tokens
params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False)
del params
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
def generate_image_tokens_torch(
text_tokens: numpy.ndarray,
seed: int,
config: dict,
params: dict,
image_token_count: int
) -> LongTensor:
def init_detokenizer(self):
super().init_detokenizer()
if torch.cuda.is_available():
self.detokenizer = self.detokenizer.cuda()
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,
params
)
image_tokens = decode_torch(
text_tokens,
encoder_state,
config,
seed,
params,
image_token_count
)
if not self.is_reusable: self.init_encoder()
print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
if not self.is_reusable: del self.encoder
if not self.is_reusable: self.init_decoder()
print("sampling image tokens")
torch.manual_seed(seed)
image_tokens = self.decoder.forward(text_tokens, encoder_state)
if not self.is_reusable: del self.decoder
return image_tokens
def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
def generate_image(self, text: str, seed: int) -> Image.Image:
image_tokens = self.generate_image_tokens(text, seed)
if not self.is_reusable: self.init_detokenizer()
print("detokenizing image")
model_path = './pretrained/vqgan'
params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params)
image = detokenizer.forward(image_tokens).to(torch.uint8)
return image.detach().numpy()
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
if not self.is_reusable: del self.detokenizer
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

View File

@ -13,46 +13,39 @@ class DecoderCrossAttentionFlax(AttentionFlax):
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray,
) -> jnp.ndarray:
keys: jnp.ndarray = self.k_proj(encoder_state)
values: jnp.ndarray = self.v_proj(encoder_state)
queries: jnp.ndarray = self.q_proj(decoder_state)
query_shape = queries.shape[:2] + (self.head_count, -1)
key_value_shape = keys.shape[:2] + (self.head_count, -1)
keys = keys.reshape(key_value_shape)
values = values.reshape(key_value_shape)
queries = queries.reshape(query_shape)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(decoder_state)
return self.forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
state_index: tuple
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
keys_state = lax.dynamic_update_slice(
keys_state,
self.k_proj(decoder_state).reshape(shape_split),
) -> Tuple[jnp.ndarray, jnp.ndarray]:
keys = self.k_proj(decoder_state)
values = self.v_proj(decoder_state)
queries = self.q_proj(decoder_state)
attention_state = lax.dynamic_update_slice(
attention_state,
jnp.concatenate([keys, values]),
state_index
)
values_state = lax.dynamic_update_slice(
values_state,
self.v_proj(decoder_state).reshape(shape_split),
state_index
)
queries = self.q_proj(decoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
batch_count = decoder_state.shape[0]
keys, values = attention_state[:batch_count], attention_state[batch_count:]
decoder_state = self.forward(
keys_state,
values_state,
keys,
values,
queries,
attention_mask
)
return decoder_state, (keys_state, values_state)
return decoder_state, attention_state
class DalleBartDecoderLayerFlax(nn.Module):
@ -77,14 +70,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
token_index: int
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
) -> Tuple[jnp.ndarray, jnp.ndarray]:
# Self Attention
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@ -92,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
jnp.arange(self.image_token_count) < token_index + 1,
(decoder_state.shape[0], 1)
)
decoder_state, keys_values_state = self.self_attn(
decoder_state, attention_state = self.self_attn(
decoder_state,
keys_state,
values_state,
attention_state,
self_attention_mask,
(0, token_index, 0, 0)
(0, token_index, 0)
)
decoder_state = self.self_attn_layer_norm(decoder_state)
decoder_state = residual + decoder_state
@ -118,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
decoder_state = self.glu(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, keys_values_state
return decoder_state, attention_state
@flax.struct.dataclass
class SampleState:
prev_token: jnp.ndarray
prng_key: jnp.ndarray
keys_state: jnp.ndarray
values_state: jnp.ndarray
attention_state: jnp.ndarray
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
return a * logits[0, -1] + (1 - a) * logits[1, -1]
@ -157,10 +148,10 @@ class DalleBartDecoderFlax(nn.Module):
)
self.layers = nn.scan(
DalleBartDecoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
out_axes = (0, 0),
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
out_axes = 0,
length=self.layer_count,
)(
self.image_token_count,
@ -173,32 +164,32 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
attention_state: jnp.ndarray,
attention_mask: jnp.ndarray,
prev_token: int,
token_index: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray]:
batch_count = encoder_state.shape[0]
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
decoder_state = self.embed_tokens(prev_token * ones)
decoder_state += self.embed_positions(token_index * ones)
decoder_state = self.layernorm_embedding(decoder_state)
decoder_state, (keys_state, values_state) = self.layers(
decoder_state, attention_state = self.layers(
decoder_state,
encoder_state,
keys_state,
values_state,
attention_state,
attention_mask,
token_index
)
decoder_state = self.final_ln(decoder_state)
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
return decoder_state, attention_state
def sample_image_tokens(self,
def sample_image_tokens(
self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey,
@ -209,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
def sample_next_image_token(
state: SampleState,
token_index: int
) -> Tuple[SampleState, None]:
logits, keys_state, values_state = self.apply(
) -> Tuple[SampleState, jnp.ndarray]:
logits, attention_state = self.apply(
{ 'params': params },
encoder_state = encoder_state,
keys_state = state.keys_state,
values_state = state.values_state,
attention_state = state.attention_state,
attention_mask = attention_mask,
prev_token = state.prev_token,
token_index = token_index
@ -229,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
state = SampleState(
prev_token = next_token,
prng_key = prng_key_next,
keys_state = keys_state,
values_state = values_state
attention_state = attention_state
)
return state, next_token
batch_count = encoder_state.shape[0]
state_shape = (
attention_state_shape = (
self.layer_count,
batch_count,
batch_count * 2,
self.image_token_count,
self.attention_head_count,
self.embed_count // self.attention_head_count
self.embed_count
)
initial_state = SampleState(
prev_token = self.start_token,
prng_key = prng_key,
keys_state = jnp.zeros(state_shape),
values_state = jnp.zeros(state_shape)
attention_state = jnp.zeros(attention_state_shape)
)
_, image_tokens = lax.scan(

View File

@ -1,7 +1,7 @@
from typing import List, Tuple
import torch
from torch import LongTensor, nn, FloatTensor, BoolTensor
torch.no_grad()
torch.set_grad_enabled(False)
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
@ -16,42 +16,35 @@ class DecoderCrossAttentionTorch(AttentionTorch):
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(decoder_state)
query_shape = queries.shape[:2] + (self.head_count, -1)
key_value_shape = keys.shape[:2] + (self.head_count, -1)
keys = keys.reshape(key_value_shape)
values = values.reshape(key_value_shape)
queries = queries.reshape(query_shape)
queries /= queries.shape[-1] ** 0.5
return super().forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
keys_values: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_index: LongTensor
token_mask: BoolTensor
) -> Tuple[FloatTensor, FloatTensor]:
batch_count = decoder_state.shape[0]
token_count = keys_values.shape[1]
shape = (batch_count, 1) + keys_values.shape[2:]
keys = self.k_proj.forward(decoder_state).view(shape)
values = self.v_proj.forward(decoder_state).view(shape)
token_mask = torch.arange(token_count) == token_index
keys_values = torch.where(
token_mask[None, :, None, None],
keys = self.k_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state)
queries = self.q_proj.forward(decoder_state)
attention_state = torch.where(
token_mask[None, :, None],
torch.cat([keys, values]),
keys_values
attention_state
)
queries = self.q_proj.forward(decoder_state).reshape(shape)
queries /= queries.shape[-1] ** 0.5
keys, values = keys_values[:batch_count], keys_values[batch_count:]
batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count]
values = attention_state[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask)
return decoder_state, keys_values
return decoder_state, attention_state
class DecoderLayerTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_token_count: int,
head_count: int,
embed_count: int,
@ -67,23 +60,29 @@ class DecoderLayerTorch(nn.Module):
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLUTorch(embed_count, glu_embed_count)
def forward(self,
self.token_indices = torch.arange(self.image_token_count)
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
# Self Attention
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
self_attn_mask = self.token_indices < token_index + 1
token_mask = self.token_indices == token_index
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
decoder_state, keys_values_state = self.self_attn.forward(
decoder_state, attention_state = self.self_attn.forward(
decoder_state,
keys_values_state,
attention_state,
self_attn_mask,
token_index
token_mask
)
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
decoder_state = residual + decoder_state
@ -104,11 +103,12 @@ class DecoderLayerTorch(nn.Module):
decoder_state = self.glu.forward(decoder_state)
decoder_state = residual + decoder_state
return decoder_state, keys_values_state
return decoder_state, attention_state
class DalleBartDecoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_vocab_size: int,
image_token_count: int,
sample_token_count: int,
@ -124,13 +124,7 @@ class DalleBartDecoderTorch(nn.Module):
self.is_verbose = is_verbose
self.layer_count = layer_count
self.sample_token_count = sample_token_count
self.start_token = torch.tensor([start_token]).to(torch.long)
self.pad_token = torch.tensor([1]).to(torch.long)
self.condition_factor = torch.tensor([10]).to(torch.float)
if torch.cuda.is_available():
self.start_token = self.start_token.cuda()
self.pad_token = self.pad_token.cuda()
self.condition_factor = self.condition_factor.cuda()
self.condition_factor = 10.0
self.image_token_count = image_token_count
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
self.embed_positions = nn.Embedding(image_token_count, embed_count)
@ -146,77 +140,82 @@ class DalleBartDecoderTorch(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
self.keys_values_state_shape = (
layer_count * 2 * batch_count,
self.attention_state_shape = (
layer_count,
2 * batch_count,
image_token_count,
attention_head_count,
embed_count // attention_head_count
embed_count
)
self.zero_prob = torch.zeros([1])
self.token_indices = torch.arange(self.sample_token_count)
self.start_token = torch.tensor([start_token]).to(torch.long)
if torch.cuda.is_available():
self.zero_prob = self.zero_prob.cuda()
self.token_indices = self.token_indices.cuda()
self.start_token = self.start_token.cuda()
def decode_step(self,
def decode_step(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
prev_token_and_index: LongTensor
attention_state: FloatTensor,
prev_token: LongTensor,
token_index: LongTensor
) -> Tuple[LongTensor, FloatTensor]:
attention_mask = text_tokens.not_equal(self.pad_token)
attention_mask = text_tokens.not_equal(1)
batch_count = encoder_state.shape[0]
prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
decoder_state = self.embed_tokens.forward(prev_token)
decoder_state += self.embed_positions.forward(token_index)
prev_token_batched = torch.cat([prev_token] * batch_count)
token_index_batched = torch.cat([token_index] * batch_count)
decoder_state = self.embed_tokens.forward(prev_token_batched)
decoder_state += self.embed_positions.forward(token_index_batched)
decoder_state = self.layernorm_embedding.forward(decoder_state)
decoder_state = decoder_state[:, None]
keys_values = []
for i, layer in enumerate(self.layers):
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
decoder_state, keys_values_layer = layer.forward(
attention_states_new = []
for i in range(self.layer_count):
decoder_state, attention_state_layer = self.layers[i].forward(
decoder_state,
encoder_state,
keys_values_state[j1:j2],
attention_state[i],
attention_mask,
token_index[:1]
token_index
)
keys_values.append(keys_values_layer)
keys_values = torch.cat(keys_values, dim=0)
attention_states_new.append(attention_state_layer)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = self.condition_factor
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
top_logits = logits.sort(descending=True)[0][:50]
top_logits, _ = logits.topk(50, dim=-1)
probs = torch.where(
logits < top_logits[-1],
torch.zeros([1]),
self.zero_prob,
torch.exp(logits - top_logits[0])
)
return probs, keys_values
return probs, torch.stack(attention_states_new)
def forward(self,
def forward(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> LongTensor:
image_tokens: List[LongTensor] = []
keys_values_state = torch.zeros(self.keys_values_state_shape)
attention_state = torch.zeros(self.attention_state_shape)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_token = self.start_token
for i in range(self.sample_token_count):
token_index = torch.tensor([i]).to(torch.long)
if torch.cuda.is_available(): token_index = token_index.cuda()
probs, keys_values_state = self.decode_step(
probs, attention_state = self.decode_step(
text_tokens = text_tokens,
encoder_state = encoder_state,
keys_values_state = keys_values_state,
prev_token_and_index = torch.cat([image_token, token_index])
attention_state = attention_state,
prev_token = image_token,
token_index = self.token_indices[[i]]
)
image_token = torch.multinomial(probs, 1)
image_tokens += [image_token]
if self.is_verbose:
token = int(image_token.detach().numpy())
print("image token {} is {}".format(i, token))
return torch.cat(image_tokens)

View File

@ -34,12 +34,17 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self,
def forward(
self,
keys: jnp.ndarray,
values: jnp.ndarray,
queries: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias: jnp.ndarray = lax.select(
attention_mask,
jnp.full(attention_mask.shape, 0.0),
@ -69,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
keys = self.k_proj(encoder_state).reshape(shape_split)
values = self.v_proj(encoder_state).reshape(shape_split)
queries = self.q_proj(encoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj(encoder_state)
values = self.v_proj(encoder_state)
queries = self.q_proj(encoder_state)
return self.forward(keys, values, queries, attention_mask)
@ -92,7 +95,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:
@ -120,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
self.layers = nn.scan(
DalleBartEncoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = nn.broadcast,
length = self.layer_count

View File

@ -1,7 +1,7 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.no_grad()
torch.set_grad_enabled(False)
class GLUTorch(nn.Module):
@ -34,17 +34,25 @@ class AttentionTorch(nn.Module):
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self,
def forward(
self,
keys: FloatTensor,
values: FloatTensor,
queries: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias = torch.where(
attention_mask,
torch.full(attention_mask.shape, 0.0),
torch.full(attention_mask.shape, -torch.inf),
self.one * 0,
self.one * (-torch.inf),
)
attention_weights: FloatTensor = torch.einsum(
'bqhc,bkhc->bhqk',
@ -70,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
encoder_state: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
keys = self.k_proj.forward(encoder_state).reshape(shape_split)
values = self.v_proj.forward(encoder_state).reshape(shape_split)
queries = self.q_proj.forward(encoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(encoder_state)
return super().forward(keys, values, queries, attention_mask)
@ -103,7 +109,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
layer_count: int,
embed_count: int,
attention_head_count: int,
@ -124,11 +131,14 @@ class DalleBartEncoderTorch(nn.Module):
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.token_indices = torch.arange(text_token_count).to(torch.long)
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
batch_count, token_count = text_tokens.shape
pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
batch_count = text_tokens.shape[0]
pose_tokens = torch.stack([self.token_indices] * batch_count)
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)

View File

@ -1,7 +1,7 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.no_grad()
torch.set_grad_enabled(False)
BATCH_COUNT: int = 1
@ -61,6 +61,7 @@ class AttentionBlock(Module):
h = self.proj_out.forward(h)
return x + h
class MiddleLayer(Module):
def __init__(self):
super().__init__()
@ -74,6 +75,7 @@ class MiddleLayer(Module):
h = self.block_2.forward(h)
return h
class Upsample(Module):
def __init__(self, log2_count):
super().__init__()
@ -86,6 +88,7 @@ class Upsample(Module):
x = self.conv.forward(x)
return x
class UpsampleBlock(Module):
def __init__(
self,
@ -124,6 +127,7 @@ class UpsampleBlock(Module):
h = self.upsample.forward(h)
return h
class Decoder(Module):
def __init__(self):
super().__init__()
@ -154,6 +158,7 @@ class Decoder(Module):
z = self.conv_out.forward(z)
return z
class VQGanDetokenizer(Module):
def __init__(self):
super().__init__()

View File

@ -8,7 +8,7 @@ class TextTokenizer:
pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def __call__(self, text: str) -> List[int]:
def tokenize(self, text: str) -> List[int]:
sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>']

23
predict.py Normal file
View File

@ -0,0 +1,23 @@
import tempfile
from cog import BasePredictor, Path, Input
from min_dalle.min_dalle_torch import MinDalleTorch
class Predictor(BasePredictor):
def setup(self):
self.model = MinDalleTorch(is_mega=True)
def predict(
self,
text: str = Input(
description="Text for generating images.",
),
seed: int = Input(
description="Specify the seed.",
),
) -> Path:
image = self.model.generate_image(text, seed)
out_path = Path(tempfile.mkdtemp()) / "output.png"
image.save(str(out_path))
return out_path

1
requirements.txt vendored
View File

@ -1,2 +1,3 @@
torch
flax==0.4.2
wandb

10
setup.sh vendored
View File

@ -1,15 +1,15 @@
#!/bin/bash
set -e
pip install -r requirements.txt
mkdir -p pretrained
mkdir -p pretrained/vqgan
# download vqgan
git lfs install
git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
# download dalle-mini and dalle mega
pip install wandb
python -m wandb login
python -m wandb login --anonymously
python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14