Merge branch 'main' into patch-1

Commit e3329a7f64 by Brett Kuprel, 2022-07-01 02:40:38 -04:00 (committed via GitHub)
22 changed files with 593 additions and 488 deletions

.gitattributes vendored Normal file

@@ -0,0 +1,2 @@
* linguist-vendored
*.py linguist-vendored=false

README.md vendored

@@ -1,28 +1,33 @@
 # min(DALL·E)
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
+[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
-This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
+This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which are saved with `torch.save` the first time the model is loaded), and wandb is only used to download the models.
+It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
 ### Setup
-Run `sh setup.sh` to install dependencies and download pretrained models. In the bash script, Git LFS is used to download the VQGan detokenizer from Hugging Face and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. These models can also be downloaded manually:
+Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
 [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
 [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
 [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
 ### Usage
-Use the command line python script `image_from_text.py` to generate images. Here are some examples:
+Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
+### Examples
 ```
-python image_from_text.py --text='alien life' --seed=7
+python image_from_text.py --text='artificial intelligence' --torch
 ```
-![Alien](examples/alien.png)
+![Alien](examples/artificial_intelligence.png)
 ```
-python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
+python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
 ```
 ![Avocado Armchair](examples/avocado_armchair.png)
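The reusable-model flow the updated Usage section describes can be exercised with a few lines of Python. This is a minimal sketch, not part of the diff, assuming the `MinDalleTorch` interface introduced in this commit (constructor arguments `is_mega` and `is_reusable` as defined in `min_dalle_torch.py` below):

```python
# Minimal sketch: load the torch model once, then generate several images.
# Assumes the MinDalleTorch class added in this commit (min_dalle/min_dalle_torch.py).
from min_dalle.min_dalle_torch import MinDalleTorch

model = MinDalleTorch(is_mega=False, is_reusable=True)  # weights are loaded once here
for seed in (7, 10):
    image = model.generate_image('a comfy chair that looks like an avocado', seed)
    image.save('avocado_{}.png'.format(seed))  # generate_image returns a PIL.Image
```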

cog.yaml vendored Normal file

@@ -0,0 +1,12 @@
build:
  cuda: "11.0"
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "torch==1.10.1"
    - "flax==0.5.2"

predict: "predict.py:Predictor"

examples/alien.png vendored (binary image removed; was 55 KiB)

examples/artificial_intelligence.png vendored Normal file (binary image added; 127 KiB)

(a further binary example image, not named in this view, changed from 90 KiB to 101 KiB)

image_from_text.py

@@ -2,8 +2,8 @@ import argparse
 import os
 from PIL import Image
-from min_dalle.generate_image import generate_image_from_text
+from min_dalle.min_dalle_torch import MinDalleTorch
+from min_dalle.min_dalle_flax import MinDalleFlax

 parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
@@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
 parser.add_argument('--torch', action='store_true')
 parser.add_argument('--no-torch', dest='torch', action='store_false')
 parser.set_defaults(torch=False)
-parser.add_argument('--text', type=str)
-parser.add_argument('--seed', type=int, default=0)
+parser.add_argument('--text', type=str, default='alien life')
+parser.add_argument('--seed', type=int, default=7)
 parser.add_argument('--image_path', type=str, default='generated')
-parser.add_argument('--image_token_count', type=int, default=256) # for debugging
+parser.add_argument('--token_count', type=int, default=256) # for debugging


 def ascii_from_image(image: Image.Image, size: int) -> str:
@@ -36,19 +36,41 @@ def save_image(image: Image.Image, path: str):
     return image


+def generate_image(
+    is_torch: bool,
+    is_mega: bool,
+    text: str,
+    seed: int,
+    image_path: str,
+    token_count: int
+):
+    is_reusable = False
+    if is_torch:
+        image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
+        if token_count < image_generator.config['image_length']:
+            image_tokens = image_generator.generate_image_tokens(text, seed)
+            print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
+            return
+        else:
+            image = image_generator.generate_image(text, seed)
+    else:
+        image_generator = MinDalleFlax(is_mega, is_reusable)
+        image = image_generator.generate_image(text, seed)
+
+    save_image(image, image_path)
+    print(ascii_from_image(image, size=128))
+
+
 if __name__ == '__main__':
     args = parser.parse_args()
     print(args)
-    image = generate_image_from_text(
-        text = args.text,
-        is_mega = args.mega,
+    generate_image(
         is_torch=args.torch,
+        is_mega=args.mega,
+        text=args.text,
         seed=args.seed,
-        image_token_count = args.image_token_count
+        image_path=args.image_path,
+        token_count=args.token_count
    )
-    if image != None:
-        save_image(image, args.image_path)
-        print(ascii_from_image(image, size=128))
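As a usage note (not part of the diff): the new `--token_count` flag drives the debugging branch of `generate_image` above, so any value below the model's `image_length` (256) prints the sampled image tokens instead of writing an image. A hedged sketch of calling the script's function directly:

```python
# Hypothetical illustration of the debugging path in image_from_text.generate_image.
# With token_count < config['image_length'] (256), only image tokens are printed.
from image_from_text import generate_image

generate_image(
    is_torch=True,
    is_mega=False,
    text='alien life',
    seed=7,
    image_path='generated',
    token_count=16  # fewer than 256, so no image is written
)
```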

min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

min_dalle/generate_image.py (deleted)

@@ -1,77 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List

from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
    generate_image_tokens_torch,
    detokenize_torch
)


def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
    print("parsing metadata from {}".format(path))
    for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
        assert(os.path.exists(os.path.join(path, f)))
    with open(path + '/config.json', 'r') as f:
        config = json.load(f)
    with open(path + '/vocab.json') as f:
        vocab = json.load(f)
    with open(path + '/merges.txt') as f:
        merges = f.read().split("\n")[1:-1]
    return config, vocab, merges


def tokenize_text(
    text: str,
    config: dict,
    vocab: dict,
    merges: List[str]
) -> numpy.ndarray:
    print("tokenizing text")
    tokens = TextTokenizer(vocab, merges)(text)
    print("text tokens", tokens)
    text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
    text_tokens[0, :len(tokens)] = tokens
    text_tokens[1, :2] = [tokens[0], tokens[-1]]
    return text_tokens


def generate_image_from_text(
    text: str,
    is_mega: bool = False,
    is_torch: bool = False,
    seed: int = 0,
    image_token_count: int = 256
) -> Image.Image:
    model_name = 'mega' if is_mega else 'mini'
    model_path = './pretrained/dalle_bart_{}'.format(model_name)
    config, vocab, merges = load_dalle_bart_metadata(model_path)
    text_tokens = tokenize_text(text, config, vocab, merges)
    params_dalle_bart = load_dalle_bart_flax_params(model_path)

    image_tokens = numpy.zeros(config['image_length'])
    if is_torch:
        image_tokens[:image_token_count] = generate_image_tokens_torch(
            text_tokens = text_tokens,
            seed = seed,
            config = config,
            params = params_dalle_bart,
            image_token_count = image_token_count
        )
    else:
        image_tokens[...] = generate_image_tokens_flax(
            text_tokens = text_tokens,
            seed = seed,
            config = config,
            params = params_dalle_bart,
        )

    if image_token_count == config['image_length']:
        image = detokenize_torch(image_tokens)
        return Image.fromarray(image)
    else:
        return None

min_dalle/load_params.py

@@ -1,17 +1,17 @@
 import os
 import numpy
-from copy import deepcopy
 from typing import Dict
-from flax import traverse_util, serialization
+from flax.traverse_util import flatten_dict
+from flax.serialization import msgpack_restore
 import torch
-torch.no_grad()
+torch.set_grad_enabled(False)


 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
     with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
-        params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
+        params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())

-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
+    P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')

     for i in list(P.keys()):
         j = i
@@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:

     for i in P:
         P[i] = torch.tensor(P[i])
-        if torch.cuda.is_available(): P[i] = P[i].cuda()

     P['embedding.weight'] = P.pop('quantize.embedding.embedding')
@@ -43,7 +42,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:

 def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
     with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
-        params = serialization.msgpack_restore(f.read())
+        params = msgpack_restore(f.read())

     for codec in ['encoder', 'decoder']:
         k = 'FlaxBart{}Layers'.format(codec.title())
@@ -82,12 +81,10 @@ convert_dalle_bart_torch_from_flax_params(
     layer_count: int,
     is_encoder: bool
 ) -> dict:
-    P = deepcopy(params)
-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
+    P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')

     for i in P:
-        P[i] = torch.tensor(P[i])
-        if torch.cuda.is_available(): P[i] = P[i].cuda()
+        P[i] = torch.tensor(P[i]).to(torch.float16)

     for i in list(P):
         if 'kernel' in i:
@@ -108,3 +105,28 @@ convert_dalle_bart_torch_from_flax_params(
     P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
     P['embed_positions.weight'] = P.pop('embed_positions.embedding')
     return P
+
+
+def convert_and_save_torch_params(is_mega: bool, model_path: str):
+    print("converting params to torch")
+    layer_count = 24 if is_mega else 12
+    flax_params = load_dalle_bart_flax_params(model_path)
+    encoder_params = convert_dalle_bart_torch_from_flax_params(
+        flax_params['encoder'],
+        layer_count=layer_count,
+        is_encoder=True
+    )
+    decoder_params = convert_dalle_bart_torch_from_flax_params(
+        flax_params['decoder'],
+        layer_count=layer_count,
+        is_encoder=False
+    )
+
+    for i in decoder_params:
+        decoder_params[i] = decoder_params[i].to(torch.float16)
+
+    for i in encoder_params:
+        encoder_params[i] = encoder_params[i].to(torch.float16)
+
+    torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
+    torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))

min_dalle/min_dalle_base.py Normal file

@@ -0,0 +1,46 @@
import os
import json
import numpy

from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer


class MinDalleBase:
    def __init__(self, is_mega: bool):
        self.is_mega = is_mega
        model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
        self.model_path = os.path.join('pretrained', model_name)

        print("reading files from {}".format(self.model_path))
        config_path = os.path.join(self.model_path, 'config.json')
        vocab_path = os.path.join(self.model_path, 'vocab.json')
        merges_path = os.path.join(self.model_path, 'merges.txt')

        with open(config_path, 'r', encoding='utf8') as f:
            self.config = json.load(f)
        with open(vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)
        with open(merges_path, 'r', encoding='utf8') as f:
            merges = f.read().split("\n")[1:-1]

        self.tokenizer = TextTokenizer(vocab, merges)

    def init_detokenizer(self):
        print("initializing VQGanDetokenizer")
        params = load_vqgan_torch_params('./pretrained/vqgan')
        self.detokenizer = VQGanDetokenizer()
        self.detokenizer.load_state_dict(params)
        del params

    def tokenize_text(self, text: str) -> numpy.ndarray:
        print("tokenizing text")
        tokens = self.tokenizer.tokenize(text)
        print("text tokens", tokens)
        text_token_count = self.config['max_text_length']
        text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
        text_tokens[0, :len(tokens)] = tokens
        text_tokens[1, :2] = [tokens[0], tokens[-1]]
        return text_tokens

min_dalle/min_dalle_flax.py

@@ -1,79 +1,80 @@
 import jax
-from jax import numpy as jnp
 import numpy
+from PIL import Image
+import torch

+from .min_dalle_base import MinDalleBase
 from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
 from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
+from .load_params import load_dalle_bart_flax_params


-def encode_flax(
-    text_tokens: numpy.ndarray,
-    config: dict,
-    params: dict
-) -> jnp.ndarray:
-    print("loading flax encoder")
-    encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
-        attention_head_count = config['encoder_attention_heads'],
-        embed_count = config['d_model'],
-        glu_embed_count = config['encoder_ffn_dim'],
-        text_token_count = config['max_text_length'],
-        text_vocab_count = config['encoder_vocab_size'],
-        layer_count = config['encoder_layers']
-    ).bind({'params': params.pop('encoder')})
-    print("encoding text tokens")
-    encoder_state = encoder(text_tokens)
-    del encoder
-    return encoder_state
-
-
-def decode_flax(
-    text_tokens: jnp.ndarray,
-    encoder_state: jnp.ndarray,
-    config: dict,
-    seed: int,
-    params: dict
-) -> jnp.ndarray:
-    print("loading flax decoder")
-    decoder = DalleBartDecoderFlax(
-        image_token_count = config['image_length'],
-        text_token_count = config['max_text_length'],
-        image_vocab_count = config['image_vocab_size'],
-        attention_head_count = config['decoder_attention_heads'],
-        embed_count = config['d_model'],
-        glu_embed_count = config['decoder_ffn_dim'],
-        layer_count = config['decoder_layers'],
-        start_token = config['decoder_start_token_id']
-    )
-    print("sampling image tokens")
-    image_tokens = decoder.sample_image_tokens(
-        text_tokens,
-        encoder_state,
-        jax.random.PRNGKey(seed),
-        params.pop('decoder')
-    )
-    del decoder
-    return image_tokens
-
-
-def generate_image_tokens_flax(
-    text_tokens: numpy.ndarray,
-    seed: int,
-    config: dict,
-    params: dict
-) -> numpy.ndarray:
-    encoder_state = encode_flax(
-        text_tokens,
-        config,
-        params
-    )
-    image_tokens = decode_flax(
-        text_tokens,
-        encoder_state,
-        config,
-        seed,
-        params
-    )
-    image_tokens = numpy.array(image_tokens)
-    print("image tokens", list(image_tokens))
-    return image_tokens
+class MinDalleFlax(MinDalleBase):
+    def __init__(self, is_mega: bool, is_reusable: bool = True):
+        super().__init__(is_mega)
+        self.is_reusable = is_reusable
+        print("initializing MinDalleFlax")
+        self.model_params = load_dalle_bart_flax_params(self.model_path)
+        if is_reusable:
+            self.init_encoder()
+            self.init_decoder()
+            self.init_detokenizer()
+
+    def init_encoder(self):
+        print("initializing DalleBartEncoderFlax")
+        self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
+            attention_head_count = self.config['encoder_attention_heads'],
+            embed_count = self.config['d_model'],
+            glu_embed_count = self.config['encoder_ffn_dim'],
+            text_token_count = self.config['max_text_length'],
+            text_vocab_count = self.config['encoder_vocab_size'],
+            layer_count = self.config['encoder_layers']
+        ).bind({'params': self.model_params.pop('encoder')})
+
+    def init_decoder(self):
+        print("initializing DalleBartDecoderFlax")
+        self.decoder = DalleBartDecoderFlax(
+            image_token_count = self.config['image_length'],
+            text_token_count = self.config['max_text_length'],
+            image_vocab_count = self.config['image_vocab_size'],
+            attention_head_count = self.config['decoder_attention_heads'],
+            embed_count = self.config['d_model'],
+            glu_embed_count = self.config['decoder_ffn_dim'],
+            layer_count = self.config['decoder_layers'],
+            start_token = self.config['decoder_start_token_id']
+        )
+
+    def generate_image(self, text: str, seed: int) -> Image.Image:
+        text_tokens = self.tokenize_text(text)
+
+        if not self.is_reusable: self.init_encoder()
+        print("encoding text tokens")
+        encoder_state = self.encoder(text_tokens)
+        if not self.is_reusable: del self.encoder
+
+        if not self.is_reusable:
+            self.init_decoder()
+            params = self.model_params.pop('decoder')
+        else:
+            params = self.model_params['decoder']
+        print("sampling image tokens")
+        image_tokens = self.decoder.sample_image_tokens(
+            text_tokens,
+            encoder_state,
+            jax.random.PRNGKey(seed),
+            params
+        )
+        if not self.is_reusable: del self.decoder
+
+        image_tokens = torch.tensor(numpy.array(image_tokens))
+
+        if not self.is_reusable: self.init_detokenizer()
+        print("detokenizing image")
+        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        if not self.is_reusable: del self.detokenizer
+        image = Image.fromarray(image.to('cpu').detach().numpy())
+        return image

min_dalle/min_dalle_torch.py

@@ -1,113 +1,114 @@
-import numpy
+import os
+from PIL import Image
 from typing import Dict
-from torch import LongTensor, FloatTensor
+from torch import LongTensor
 import torch
-torch.no_grad()
+torch.set_grad_enabled(False)
+torch.set_num_threads(os.cpu_count())

-from .models.vqgan_detokenizer import VQGanDetokenizer
+from .load_params import (
+    convert_and_save_torch_params,
+    load_dalle_bart_flax_params
+)
+from .min_dalle_base import MinDalleBase
 from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
 from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
-from .load_params import (
-    load_vqgan_torch_params,
-    convert_dalle_bart_torch_from_flax_params
-)


-def encode_torch(
-    text_tokens: LongTensor,
-    config: dict,
-    params: dict
-) -> FloatTensor:
-    print("loading torch encoder")
-    encoder = DalleBartEncoderTorch(
-        layer_count = config['encoder_layers'],
-        embed_count = config['d_model'],
-        attention_head_count = config['encoder_attention_heads'],
-        text_vocab_count = config['encoder_vocab_size'],
-        text_token_count = config['max_text_length'],
-        glu_embed_count = config['encoder_ffn_dim']
-    )
-    encoder_params = convert_dalle_bart_torch_from_flax_params(
-        params.pop('encoder'),
-        layer_count=config['encoder_layers'],
-        is_encoder=True
-    )
-    encoder.load_state_dict(encoder_params, strict=False)
-    del encoder_params
-    print("encoding text tokens")
-    encoder_state = encoder(text_tokens)
-    del encoder
-    return encoder_state
-
-
-def decode_torch(
-    text_tokens: LongTensor,
-    encoder_state: FloatTensor,
-    config: dict,
-    seed: int,
-    params: dict,
-    image_token_count: int
-) -> LongTensor:
-    print("loading torch decoder")
-    decoder = DalleBartDecoderTorch(
-        image_vocab_size = config['image_vocab_size'],
-        image_token_count = config['image_length'],
-        sample_token_count = image_token_count,
-        embed_count = config['d_model'],
-        attention_head_count = config['decoder_attention_heads'],
-        glu_embed_count = config['decoder_ffn_dim'],
-        layer_count = config['decoder_layers'],
-        batch_count = 2,
-        start_token = config['decoder_start_token_id'],
-        is_verbose = True
-    )
-    decoder_params = convert_dalle_bart_torch_from_flax_params(
-        params.pop('decoder'),
-        layer_count=config['decoder_layers'],
-        is_encoder=False
-    )
-    decoder.load_state_dict(decoder_params, strict=False)
-    del decoder_params
-    print("sampling image tokens")
-    torch.manual_seed(seed)
-    image_tokens = decoder.forward(text_tokens, encoder_state)
-    return image_tokens
-
-
-def generate_image_tokens_torch(
-    text_tokens: numpy.ndarray,
-    seed: int,
-    config: dict,
-    params: dict,
-    image_token_count: int
-) -> LongTensor:
-    text_tokens = torch.tensor(text_tokens).to(torch.long)
-    if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
-    encoder_state = encode_torch(
-        text_tokens,
-        config,
-        params
-    )
-    image_tokens = decode_torch(
-        text_tokens,
-        encoder_state,
-        config,
-        seed,
-        params,
-        image_token_count
-    )
-    return image_tokens
-
-
-def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
-    print("detokenizing image")
-    model_path = './pretrained/vqgan'
-    params = load_vqgan_torch_params(model_path)
-    detokenizer = VQGanDetokenizer()
-    detokenizer.load_state_dict(params)
-    image = detokenizer.forward(image_tokens).to(torch.uint8)
-    return image.detach().numpy()
+class MinDalleTorch(MinDalleBase):
+    def __init__(
+        self,
+        is_mega: bool,
+        is_reusable: bool = True,
+        token_count: int = 256
+    ):
+        print("initializing MinDalleTorch")
+        super().__init__(is_mega)
+        self.is_reusable = is_reusable
+        self.token_count = token_count
+        if not is_mega:
+            self.model_params = load_dalle_bart_flax_params(self.model_path)
+
+        self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
+        self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
+        is_converted = os.path.exists(self.encoder_params_path)
+        is_converted &= os.path.exists(self.decoder_params_path)
+        if not is_converted:
+            convert_and_save_torch_params(is_mega, self.model_path)
+
+        if is_reusable:
+            self.init_encoder()
+            self.init_decoder()
+            self.init_detokenizer()
+
+    def init_encoder(self):
+        print("initializing DalleBartEncoderTorch")
+        self.encoder = DalleBartEncoderTorch(
+            layer_count = self.config['encoder_layers'],
+            embed_count = self.config['d_model'],
+            attention_head_count = self.config['encoder_attention_heads'],
+            text_vocab_count = self.config['encoder_vocab_size'],
+            text_token_count = self.config['max_text_length'],
+            glu_embed_count = self.config['encoder_ffn_dim']
+        )
+        params = torch.load(self.encoder_params_path)
+        self.encoder.load_state_dict(params, strict=False)
+        del params
+        if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
+
+    def init_decoder(self):
+        print("initializing DalleBartDecoderTorch")
+        self.decoder = DalleBartDecoderTorch(
+            image_vocab_size = self.config['image_vocab_size'],
+            image_token_count = self.config['image_length'],
+            sample_token_count = self.token_count,
+            embed_count = self.config['d_model'],
+            attention_head_count = self.config['decoder_attention_heads'],
+            glu_embed_count = self.config['decoder_ffn_dim'],
+            layer_count = self.config['decoder_layers'],
+            batch_count = 2,
+            start_token = self.config['decoder_start_token_id'],
+            is_verbose = True
+        )
+        params = torch.load(self.decoder_params_path)
+        self.decoder.load_state_dict(params, strict=False)
+        del params
+        if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
+
+    def init_detokenizer(self):
+        super().init_detokenizer()
+        if torch.cuda.is_available():
+            self.detokenizer = self.detokenizer.cuda()
+
+    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
+        text_tokens = self.tokenize_text(text)
+        text_tokens = torch.tensor(text_tokens).to(torch.long)
+        if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
+
+        if not self.is_reusable: self.init_encoder()
+        print("encoding text tokens")
+        encoder_state = self.encoder.forward(text_tokens)
+        if not self.is_reusable: del self.encoder
+
+        if not self.is_reusable: self.init_decoder()
+        print("sampling image tokens")
+        torch.manual_seed(seed)
+        image_tokens = self.decoder.forward(text_tokens, encoder_state)
+        if not self.is_reusable: del self.decoder
+        return image_tokens
+
+    def generate_image(self, text: str, seed: int) -> Image.Image:
+        image_tokens = self.generate_image_tokens(text, seed)
+        if not self.is_reusable: self.init_detokenizer()
+        print("detokenizing image")
+        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        if not self.is_reusable: del self.detokenizer
+        image = Image.fromarray(image.to('cpu').detach().numpy())
+        return image

min_dalle/models/dalle_bart_decoder_flax.py

@@ -13,46 +13,39 @@ class DecoderCrossAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
     ) -> jnp.ndarray:
-        keys: jnp.ndarray = self.k_proj(encoder_state)
-        values: jnp.ndarray = self.v_proj(encoder_state)
-        queries: jnp.ndarray = self.q_proj(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(decoder_state)
         return self.forward(keys, values, queries, attention_mask)


 class DecoderSelfAttentionFlax(AttentionFlax):
-    def __call__(self,
+    def __call__(
+        self,
         decoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         state_index: tuple
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
-        shape_split = decoder_state.shape[:2] + (self.head_count, -1)
-        keys_state = lax.dynamic_update_slice(
-            keys_state,
-            self.k_proj(decoder_state).reshape(shape_split),
-            state_index
-        )
-        values_state = lax.dynamic_update_slice(
-            values_state,
-            self.v_proj(decoder_state).reshape(shape_split),
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        keys = self.k_proj(decoder_state)
+        values = self.v_proj(decoder_state)
+        queries = self.q_proj(decoder_state)
+        attention_state = lax.dynamic_update_slice(
+            attention_state,
+            jnp.concatenate([keys, values]),
             state_index
         )
-        queries = self.q_proj(decoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        batch_count = decoder_state.shape[0]
+        keys, values = attention_state[:batch_count], attention_state[batch_count:]
         decoder_state = self.forward(
-            keys_state,
-            values_state,
+            keys,
+            values,
             queries,
             attention_mask
         )
-        return decoder_state, (keys_state, values_state)
+        return decoder_state, attention_state


 class DalleBartDecoderLayerFlax(nn.Module):
@@ -77,14 +70,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

     @nn.compact
-    def __call__(self,
+    def __call__(
+        self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         token_index: int
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm(decoder_state)
@@ -92,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
             jnp.arange(self.image_token_count) < token_index + 1,
             (decoder_state.shape[0], 1)
         )
-        decoder_state, keys_values_state = self.self_attn(
+        decoder_state, attention_state = self.self_attn(
             decoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             self_attention_mask,
-            (0, token_index, 0, 0)
+            (0, token_index, 0)
         )
         decoder_state = self.self_attn_layer_norm(decoder_state)
         decoder_state = residual + decoder_state
@@ -118,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
         decoder_state = self.glu(decoder_state)
         decoder_state = residual + decoder_state

-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 @flax.struct.dataclass
 class SampleState:
     prev_token: jnp.ndarray
     prng_key: jnp.ndarray
-    keys_state: jnp.ndarray
-    values_state: jnp.ndarray
+    attention_state: jnp.ndarray

 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
@@ -157,10 +148,10 @@ class DalleBartDecoderFlax(nn.Module):
         )
         self.layers = nn.scan(
             DalleBartDecoderLayerFlax,
-            variable_axes = { "params": 0, "cache": 0 },
+            variable_axes = { "params": 0 },
             split_rngs = { "params": True },
-            in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
-            out_axes = (0, 0),
+            in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
+            out_axes = 0,
             length=self.layer_count,
         )(
             self.image_token_count,
@@ -173,32 +164,32 @@ class DalleBartDecoderFlax(nn.Module):
         self.final_ln = nn.LayerNorm(use_scale=False)
         self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)

-    def __call__(self,
+    def __call__(
+        self,
         encoder_state: jnp.ndarray,
-        keys_state: jnp.ndarray,
-        values_state: jnp.ndarray,
+        attention_state: jnp.ndarray,
         attention_mask: jnp.ndarray,
         prev_token: int,
         token_index: int
-    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         batch_count = encoder_state.shape[0]
         ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
         decoder_state = self.embed_tokens(prev_token * ones)
         decoder_state += self.embed_positions(token_index * ones)
         decoder_state = self.layernorm_embedding(decoder_state)
-        decoder_state, (keys_state, values_state) = self.layers(
+        decoder_state, attention_state = self.layers(
             decoder_state,
             encoder_state,
-            keys_state,
-            values_state,
+            attention_state,
             attention_mask,
             token_index
         )
         decoder_state = self.final_ln(decoder_state)
         decoder_state = self.lm_head(decoder_state)
-        return decoder_state, keys_state, values_state
+        return decoder_state, attention_state

-    def sample_image_tokens(self,
+    def sample_image_tokens(
+        self,
         text_tokens: jnp.ndarray,
         encoder_state: jnp.ndarray,
         prng_key: jax.random.PRNGKey,
@@ -209,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
         def sample_next_image_token(
             state: SampleState,
             token_index: int
-        ) -> Tuple[SampleState, None]:
-            logits, keys_state, values_state = self.apply(
+        ) -> Tuple[SampleState, jnp.ndarray]:
+            logits, attention_state = self.apply(
                 { 'params': params },
                 encoder_state = encoder_state,
-                keys_state = state.keys_state,
-                values_state = state.values_state,
+                attention_state = state.attention_state,
                 attention_mask = attention_mask,
                 prev_token = state.prev_token,
                 token_index = token_index
@@ -229,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
             state = SampleState(
                 prev_token = next_token,
                 prng_key = prng_key_next,
-                keys_state = keys_state,
-                values_state = values_state
+                attention_state = attention_state
             )
             return state, next_token

         batch_count = encoder_state.shape[0]
-        state_shape = (
+        attention_state_shape = (
             self.layer_count,
-            batch_count,
+            batch_count * 2,
             self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
+            self.embed_count
         )
         initial_state = SampleState(
             prev_token = self.start_token,
             prng_key = prng_key,
-            keys_state = jnp.zeros(state_shape),
-            values_state = jnp.zeros(state_shape)
+            attention_state = jnp.zeros(attention_state_shape)
         )
         _, image_tokens = lax.scan(

min_dalle/models/dalle_bart_decoder_torch.py

@@ -1,7 +1,7 @@
 from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
-torch.no_grad()
+torch.set_grad_enabled(False)

 from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch

@@ -16,42 +16,35 @@ class DecoderCrossAttentionTorch(AttentionTorch):
         keys = self.k_proj.forward(encoder_state)
         values = self.v_proj.forward(encoder_state)
         queries = self.q_proj.forward(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
         return super().forward(keys, values, queries, attention_mask)


 class DecoderSelfAttentionTorch(AttentionTorch):
-    def forward(self,
+    def forward(
+        self,
         decoder_state: FloatTensor,
-        keys_values: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
-        token_index: LongTensor
+        token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
-        batch_count = decoder_state.shape[0]
-        token_count = keys_values.shape[1]
-        shape = (batch_count, 1) + keys_values.shape[2:]
-        keys = self.k_proj.forward(decoder_state).view(shape)
-        values = self.v_proj.forward(decoder_state).view(shape)
-        token_mask = torch.arange(token_count) == token_index
-        keys_values = torch.where(
-            token_mask[None, :, None, None],
+        keys = self.k_proj.forward(decoder_state)
+        values = self.v_proj.forward(decoder_state)
+        queries = self.q_proj.forward(decoder_state)
+        attention_state = torch.where(
+            token_mask[None, :, None],
             torch.cat([keys, values]),
-            keys_values
+            attention_state
        )
-        queries = self.q_proj.forward(decoder_state).reshape(shape)
-        queries /= queries.shape[-1] ** 0.5
-        keys, values = keys_values[:batch_count], keys_values[batch_count:]
+        batch_count = decoder_state.shape[0]
+        keys = attention_state[:batch_count]
+        values = attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
-        return decoder_state, keys_values
+        return decoder_state, attention_state


 class DecoderLayerTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         image_token_count: int,
         head_count: int,
         embed_count: int,
@@ -67,23 +60,29 @@ class DecoderLayerTorch(nn.Module):
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
         self.glu = GLUTorch(embed_count, glu_embed_count)

-    def forward(self,
+        self.token_indices = torch.arange(self.image_token_count)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()
+
+    def forward(
+        self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
+        attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
-        self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
+        self_attn_mask = self.token_indices < token_index + 1
+        token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
-        decoder_state, keys_values_state = self.self_attn.forward(
+        decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
-            keys_values_state,
+            attention_state,
             self_attn_mask,
-            token_index
+            token_mask
         )
         decoder_state = self.self_attn_layer_norm.forward(decoder_state)
         decoder_state = residual + decoder_state
@@ -104,11 +103,12 @@ class DecoderLayerTorch(nn.Module):
         decoder_state = self.glu.forward(decoder_state)
         decoder_state = residual + decoder_state

-        return decoder_state, keys_values_state
+        return decoder_state, attention_state


 class DalleBartDecoderTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         image_vocab_size: int,
         image_token_count: int,
         sample_token_count: int,
@@ -124,13 +124,7 @@ class DalleBartDecoderTorch(nn.Module):
         self.is_verbose = is_verbose
         self.layer_count = layer_count
         self.sample_token_count = sample_token_count
-        self.start_token = torch.tensor([start_token]).to(torch.long)
-        self.pad_token = torch.tensor([1]).to(torch.long)
-        self.condition_factor = torch.tensor([10]).to(torch.float)
-        if torch.cuda.is_available():
-            self.start_token = self.start_token.cuda()
-            self.pad_token = self.pad_token.cuda()
-            self.condition_factor = self.condition_factor.cuda()
+        self.condition_factor = 10.0
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
@@ -146,77 +140,82 @@ class DalleBartDecoderTorch(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
-        self.keys_values_state_shape = (
-            layer_count * 2 * batch_count,
+        self.attention_state_shape = (
+            layer_count,
+            2 * batch_count,
             image_token_count,
-            attention_head_count,
-            embed_count // attention_head_count
+            embed_count
         )
+        self.zero_prob = torch.zeros([1])
+        self.token_indices = torch.arange(self.sample_token_count)
+        self.start_token = torch.tensor([start_token]).to(torch.long)
+        if torch.cuda.is_available():
+            self.zero_prob = self.zero_prob.cuda()
+            self.token_indices = self.token_indices.cuda()
+            self.start_token = self.start_token.cuda()

-    def decode_step(self,
+    def decode_step(
+        self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
-        keys_values_state: FloatTensor,
-        prev_token_and_index: LongTensor
+        attention_state: FloatTensor,
+        prev_token: LongTensor,
+        token_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
-        attention_mask = text_tokens.not_equal(self.pad_token)
+        attention_mask = text_tokens.not_equal(1)
         batch_count = encoder_state.shape[0]
-        prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
-        token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
-        decoder_state = self.embed_tokens.forward(prev_token)
-        decoder_state += self.embed_positions.forward(token_index)
+        prev_token_batched = torch.cat([prev_token] * batch_count)
+        token_index_batched = torch.cat([token_index] * batch_count)
+        decoder_state = self.embed_tokens.forward(prev_token_batched)
+        decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
-        keys_values = []
-        for i, layer in enumerate(self.layers):
-            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
-            decoder_state, keys_values_layer = layer.forward(
+        attention_states_new = []
+        for i in range(self.layer_count):
+            decoder_state, attention_state_layer = self.layers[i].forward(
                 decoder_state,
                 encoder_state,
-                keys_values_state[j1:j2],
+                attention_state[i],
                 attention_mask,
-                token_index[:1]
+                token_index
             )
-            keys_values.append(keys_values_layer)
-        keys_values = torch.cat(keys_values, dim=0)
+            attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
         logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]

-        top_logits = logits.sort(descending=True)[0][:50]
+        top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
             logits < top_logits[-1],
-            torch.zeros([1]),
+            self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, keys_values
+        return probs, torch.stack(attention_states_new)

-    def forward(self,
+    def forward(
+        self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
-        keys_values_state = torch.zeros(self.keys_values_state_shape)
+        attention_state = torch.zeros(self.attention_state_shape)
+        if torch.cuda.is_available():
+            attention_state = attention_state.cuda()
         image_token = self.start_token

         for i in range(self.sample_token_count):
-            token_index = torch.tensor([i]).to(torch.long)
-            if torch.cuda.is_available(): token_index = token_index.cuda()
-            probs, keys_values_state = self.decode_step(
+            probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
-                keys_values_state = keys_values_state,
-                prev_token_and_index = torch.cat([image_token, token_index])
+                attention_state = attention_state,
+                prev_token = image_token,
+                token_index = self.token_indices[[i]]
             )
             image_token = torch.multinomial(probs, 1)
             image_tokens += [image_token]
-            if self.is_verbose:
-                token = int(image_token.detach().numpy())
-                print("image token {} is {}".format(i, token))

         return torch.cat(image_tokens)

min_dalle/models/dalle_bart_encoder_flax.py

@@ -34,12 +34,17 @@ class AttentionFlax(nn.Module):
         self.v_proj = nn.Dense(self.embed_count, use_bias=False)
         self.out_proj = nn.Dense(self.embed_count, use_bias=False)

-    def forward(self,
+    def forward(
+        self,
         keys: jnp.ndarray,
         values: jnp.ndarray,
         queries: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias: jnp.ndarray = lax.select(
             attention_mask,
             jnp.full(attention_mask.shape, 0.0),
@@ -69,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj(encoder_state).reshape(shape_split)
-        values = self.v_proj(encoder_state).reshape(shape_split)
-        queries = self.q_proj(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj(encoder_state)
+        values = self.v_proj(encoder_state)
+        queries = self.q_proj(encoder_state)
         return self.forward(keys, values, queries, attention_mask)
@@ -92,7 +95,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
         self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

     @nn.compact
-    def __call__(self,
+    def __call__(
+        self,
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
@@ -120,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
         self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
         self.layers = nn.scan(
             DalleBartEncoderLayerFlax,
-            variable_axes = { "params": 0, "cache": 0 },
+            variable_axes = { "params": 0 },
             split_rngs = { "params": True },
             in_axes = nn.broadcast,
             length = self.layer_count

min_dalle/models/dalle_bart_encoder_torch.py

@@ -1,7 +1,7 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
-torch.no_grad()
+torch.set_grad_enabled(False)


 class GLUTorch(nn.Module):
@@ -34,17 +34,25 @@ class AttentionTorch(nn.Module):
         self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
+        self.one = torch.ones((1, 1))
+        if torch.cuda.is_available(): self.one = self.one.cuda()

-    def forward(self,
+    def forward(
+        self,
         keys: FloatTensor,
         values: FloatTensor,
         queries: FloatTensor,
         attention_mask: BoolTensor
     ) -> FloatTensor:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
         attention_bias = torch.where(
             attention_mask,
-            torch.full(attention_mask.shape, 0.0),
-            torch.full(attention_mask.shape, -torch.inf),
+            self.one * 0,
+            self.one * (-torch.inf),
         )
         attention_weights: FloatTensor = torch.einsum(
             'bqhc,bkhc->bhqk',
@@ -70,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
         encoder_state: FloatTensor,
         attention_mask: BoolTensor
     ) -> FloatTensor:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj.forward(encoder_state).reshape(shape_split)
-        values = self.v_proj.forward(encoder_state).reshape(shape_split)
-        queries = self.q_proj.forward(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj.forward(encoder_state)
+        values = self.v_proj.forward(encoder_state)
+        queries = self.q_proj.forward(encoder_state)
         return super().forward(keys, values, queries, attention_mask)
@@ -103,7 +109,8 @@ class EncoderLayerTorch(nn.Module):

 class DalleBartEncoderTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         layer_count: int,
         embed_count: int,
         attention_head_count: int,
@@ -124,11 +131,14 @@ class DalleBartEncoderTorch(nn.Module):
         ])
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
+        self.token_indices = torch.arange(text_token_count).to(torch.long)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()

     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
-        batch_count, token_count = text_tokens.shape
-        pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
+        batch_count = text_tokens.shape[0]
+        pose_tokens = torch.stack([self.token_indices] * batch_count)
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)

min_dalle/models/vqgan_detokenizer.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
-torch.no_grad()
+torch.set_grad_enabled(False)

 BATCH_COUNT: int = 1

@@ -61,6 +61,7 @@ class AttentionBlock(Module):
         h = self.proj_out.forward(h)
         return x + h

+
 class MiddleLayer(Module):
     def __init__(self):
         super().__init__()
@@ -74,6 +75,7 @@ class MiddleLayer(Module):
         h = self.block_2.forward(h)
         return h

+
 class Upsample(Module):
     def __init__(self, log2_count):
         super().__init__()
@@ -86,6 +88,7 @@ class Upsample(Module):
         x = self.conv.forward(x)
         return x

+
 class UpsampleBlock(Module):
     def __init__(
         self,
@@ -124,6 +127,7 @@ class UpsampleBlock(Module):
         h = self.upsample.forward(h)
         return h

+
 class Decoder(Module):
     def __init__(self):
         super().__init__()
@@ -154,6 +158,7 @@ class Decoder(Module):
         z = self.conv_out.forward(z)
         return z

+
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()

min_dalle/text_tokenizer.py

@@ -8,7 +8,7 @@ class TextTokenizer:
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

-    def __call__(self, text: str) -> List[int]:
+    def tokenize(self, text: str) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']

predict.py Normal file

@@ -0,0 +1,23 @@
import tempfile
from cog import BasePredictor, Path, Input
from min_dalle.min_dalle_torch import MinDalleTorch


class Predictor(BasePredictor):
    def setup(self):
        self.model = MinDalleTorch(is_mega=True)

    def predict(
        self,
        text: str = Input(
            description="Text for generating images.",
        ),
        seed: int = Input(
            description="Specify the seed.",
        ),
    ) -> Path:
        image = self.model.generate_image(text, seed)
        out_path = Path(tempfile.mkdtemp()) / "output.png"
        image.save(str(out_path))
        return out_path
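predict.py is the entry point that cog.yaml's `predict: "predict.py:Predictor"` line points at. For a quick local smoke test outside of cog, here is a hedged sketch (cog itself normally constructs the class and calls `setup()` once per container):

```python
# Hypothetical local smoke test of the Replicate predictor defined above.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads DALL·E Mega once, as the Replicate runtime would
output_path = predictor.predict(text='alien life', seed=7)
print(output_path)  # path to the generated output.png
```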

requirements.txt vendored

@@ -1,2 +1,3 @@
 torch
 flax==0.4.2
+wandb

setup.sh vendored

@@ -1,15 +1,15 @@
 #!/bin/bash
+set -e

 pip install -r requirements.txt

-mkdir -p pretrained
+mkdir -p pretrained/vqgan

 # download vqgan
-git lfs install
-git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
+curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack

 # download dalle-mini and dalle mega
-pip install wandb
-python -m wandb login
+python -m wandb login --anonymously
 python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
 python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
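The same downloads can also be scripted from Python instead of the shell script. This is a hedged sketch using the public wandb API and urllib; unlike the script's anonymous CLI login, the wandb API call may require a logged-in session or API key, and the target directories match those setup.sh uses:

```python
# Sketch of the setup.sh downloads done from Python (illustrative, not part of the diff).
# Assumes a wandb login/API key is available; setup.sh instead logs in anonymously.
import os
import urllib.request
import wandb

os.makedirs('pretrained/vqgan', exist_ok=True)
urllib.request.urlretrieve(
    'https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack',
    'pretrained/vqgan/flax_model.msgpack'
)

api = wandb.Api()
api.artifact('dalle-mini/dalle-mini/mini-1:v0').download(root='pretrained/dalle_bart_mini')
api.artifact('dalle-mini/dalle-mini/mega-1-fp16:v14').download(root='pretrained/dalle_bart_mega')
```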