Merge branch 'kuprel:main' into replicate

This commit is contained in:
Chenxi 2022-06-29 19:50:47 +01:00 committed by GitHub
commit fcc17c895d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 325 additions and 330 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
* linguist-vendored
*.py linguist-vendored=false

16
README.md vendored
View File

@ -3,27 +3,31 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle) Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model. This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
### Setup ### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually: Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage ### Usage
Use the command line python script `image_from_text.py` to generate images. Here are some examples: Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
### Examples
``` ```
python image_from_text.py --text='alien life' --seed=7 python image_from_text.py --text='artificial intelligence' --torch
``` ```
![Alien](examples/alien.png) ![Alien](examples/artificial_intelligence.png)
``` ```
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4 python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
``` ```
![Avocado Armchair](examples/avocado_armchair.png) ![Avocado Armchair](examples/avocado_armchair.png)

BIN
examples/alien.png vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

BIN
examples/artificial_intelligence.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 101 KiB

View File

@ -2,8 +2,8 @@ import argparse
import os import os
from PIL import Image from PIL import Image
from min_dalle.generate_image import generate_image_from_text from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true') parser.add_argument('--mega', action='store_true')
@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
parser.add_argument('--torch', action='store_true') parser.add_argument('--torch', action='store_true')
parser.add_argument('--no-torch', dest='torch', action='store_false') parser.add_argument('--no-torch', dest='torch', action='store_false')
parser.set_defaults(torch=False) parser.set_defaults(torch=False)
parser.add_argument('--text', type=str) parser.add_argument('--text', type=str, default='alien life')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--image_path', type=str, default='generated') parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--image_token_count', type=int, default=256) # for debugging parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
def ascii_from_image(image: Image.Image, size: int) -> str: def ascii_from_image(image: Image.Image, size: int) -> str:
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
return image return image
def generate_image(
is_torch: bool,
is_mega: bool,
text: str,
seed: int,
image_path: str,
sample_token_count: int
):
if is_torch:
image_generator = MinDalleTorch(is_mega, sample_token_count)
image_tokens = image_generator.generate_image_tokens(text, seed)
if sample_token_count < image_generator.config['image_length']:
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)
else:
image_generator = MinDalleFlax(is_mega)
image = image_generator.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
generate_image(
image = generate_image_from_text(
text = args.text,
is_mega = args.mega,
is_torch=args.torch, is_torch=args.torch,
is_mega=args.mega,
text=args.text,
seed=args.seed, seed=args.seed,
image_token_count = args.image_token_count image_path=args.image_path,
sample_token_count=args.sample_token_count
) )
if image != None:
save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

92
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

View File

@ -1,78 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
import torch
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
generate_image_tokens_torch,
detokenize_torch
)
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("parsing metadata from {}".format(path))
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def tokenize_text(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
def generate_image_from_text(
text: str,
is_mega: bool = False,
is_torch: bool = False,
seed: int = 0,
image_token_count: int = 256
) -> Image.Image:
model_name = 'mega' if is_mega else 'mini'
model_path = './pretrained/dalle_bart_{}'.format(model_name)
config, vocab, merges = load_dalle_bart_metadata(model_path)
text_tokens = tokenize_text(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
if is_torch:
image_tokens = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
image_token_count = image_token_count
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens, is_torch=True)
return Image.fromarray(image)
else:
print(list(image_tokens.to('cpu').detach().numpy()))
else:
image_tokens = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
)
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
return Image.fromarray(image)

43
min_dalle/min_dalle.py Normal file
View File

@ -0,0 +1,43 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalle:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(model_path))
config_path = os.path.join(model_path, 'config.json')
vocab_path = os.path.join(model_path, 'vocab.json')
merges_path = os.path.join(model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.model_params = load_dalle_bart_flax_params(model_path)
self.tokenizer = TextTokenizer(vocab, merges)
self.detokenizer = VQGanDetokenizer()
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer.load_state_dict(vqgan_params)
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens

View File

@ -1,79 +1,58 @@
import jax import jax
from jax import numpy as jnp
import numpy import numpy
from PIL import Image
import torch
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
def encode_flax( class MinDalleFlax(MinDalle):
text_tokens: numpy.ndarray, def __init__(self, is_mega: bool):
config: dict, super().__init__(is_mega)
params: dict print("initializing MinDalleFlax")
) -> jnp.ndarray:
print("loading flax encoder") print("loading encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( self.encoder = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'], attention_head_count = self.config['encoder_attention_heads'],
embed_count = config['d_model'], embed_count = self.config['d_model'],
glu_embed_count = config['encoder_ffn_dim'], glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = config['max_text_length'], text_token_count = self.config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'], text_vocab_count = self.config['encoder_vocab_size'],
layer_count = config['encoder_layers'] layer_count = self.config['encoder_layers']
).bind({'params': params.pop('encoder')}) ).bind({'params': self.model_params.pop('encoder')})
print("loading decoder")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
)
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)
print("encoding text tokens") print("encoding text tokens")
encoder_state = encoder(text_tokens) encoder_state = self.encoder(text_tokens)
del encoder
return encoder_state
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
)
print("sampling image tokens") print("sampling image tokens")
image_tokens = decoder.sample_image_tokens( image_tokens = self.decoder.sample_image_tokens(
text_tokens, text_tokens,
encoder_state, encoder_state,
jax.random.PRNGKey(seed), jax.random.PRNGKey(seed),
params.pop('decoder') self.model_params['decoder']
) )
del decoder
return image_tokens
image_tokens = torch.tensor(numpy.array(image_tokens))
def generate_image_tokens_flax( print("detokenizing image")
text_tokens: numpy.ndarray, image = self.detokenizer.forward(image_tokens).to(torch.uint8)
seed: int, image = Image.fromarray(image.to('cpu').detach().numpy())
config: dict, return image
params: dict
) -> numpy.ndarray:
encoder_state = encode_flax(
text_tokens,
config,
params
)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config,
seed,
params
)
image_tokens = numpy.array(image_tokens)
print("image tokens", list(image_tokens))
return image_tokens

View File

@ -1,118 +1,83 @@
from random import sample
import numpy import numpy
import os import os
from PIL import Image
from typing import Dict from typing import Dict
from torch import LongTensor, FloatTensor from torch import LongTensor
import torch import torch
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count()) torch.set_num_threads(os.cpu_count())
from .models.vqgan_detokenizer import VQGanDetokenizer from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .load_params import (
load_vqgan_torch_params,
convert_dalle_bart_torch_from_flax_params
)
class MinDalleTorch(MinDalle):
def __init__(self, is_mega: bool, sample_token_count: int = 256):
super().__init__(is_mega)
print("initializing MinDalleTorch")
def encode_torch( print("loading encoder")
text_tokens: LongTensor, self.encoder = DalleBartEncoderTorch(
config: dict, layer_count = self.config['encoder_layers'],
params: dict embed_count = self.config['d_model'],
) -> FloatTensor: attention_head_count = self.config['encoder_attention_heads'],
print("loading torch encoder") text_vocab_count = self.config['encoder_vocab_size'],
encoder = DalleBartEncoderTorch( text_token_count = self.config['max_text_length'],
layer_count = config['encoder_layers'], glu_embed_count = self.config['encoder_ffn_dim']
embed_count = config['d_model'],
attention_head_count = config['encoder_attention_heads'],
text_vocab_count = config['encoder_vocab_size'],
text_token_count = config['max_text_length'],
glu_embed_count = config['encoder_ffn_dim']
) )
encoder_params = convert_dalle_bart_torch_from_flax_params( encoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('encoder'), self.model_params.pop('encoder'),
layer_count=config['encoder_layers'], layer_count=self.config['encoder_layers'],
is_encoder=True is_encoder=True
) )
encoder.load_state_dict(encoder_params, strict=False) self.encoder.load_state_dict(encoder_params, strict=False)
del encoder_params
if torch.cuda.is_available(): encoder = encoder.cuda()
print("encoding text tokens") print("loading decoder")
encoder_state = encoder(text_tokens) self.decoder = DalleBartDecoderTorch(
del encoder image_vocab_size = self.config['image_vocab_size'],
return encoder_state image_token_count = self.config['image_length'],
sample_token_count = sample_token_count,
embed_count = self.config['d_model'],
def decode_torch( attention_head_count = self.config['decoder_attention_heads'],
text_tokens: LongTensor, glu_embed_count = self.config['decoder_ffn_dim'],
encoder_state: FloatTensor, layer_count = self.config['decoder_layers'],
config: dict,
seed: int,
params: dict,
image_token_count: int
) -> LongTensor:
print("loading torch decoder")
decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'],
image_token_count = config['image_length'],
sample_token_count = image_token_count,
embed_count = config['d_model'],
attention_head_count = config['decoder_attention_heads'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
batch_count = 2, batch_count = 2,
start_token = config['decoder_start_token_id'], start_token = self.config['decoder_start_token_id'],
is_verbose = True is_verbose = True
) )
decoder_params = convert_dalle_bart_torch_from_flax_params( decoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('decoder'), self.model_params.pop('decoder'),
layer_count=config['decoder_layers'], layer_count=self.config['decoder_layers'],
is_encoder=False is_encoder=False
) )
decoder.load_state_dict(decoder_params, strict=False) self.decoder.load_state_dict(decoder_params, strict=False)
del decoder_params
if torch.cuda.is_available(): decoder = decoder.cuda() if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
self.decoder = self.decoder.cuda()
self.detokenizer = self.detokenizer.cuda()
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
print("sampling image tokens") print("sampling image tokens")
torch.manual_seed(seed) torch.manual_seed(seed)
image_tokens = decoder.forward(text_tokens, encoder_state) image_tokens = self.decoder.forward(text_tokens, encoder_state)
return image_tokens return image_tokens
def generate_image_tokens_torch( def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens: numpy.ndarray, image_tokens = self.generate_image_tokens(text, seed)
seed: int,
config: dict,
params: dict,
image_token_count: int
) -> LongTensor:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,
params
)
image_tokens = decode_torch(
text_tokens,
encoder_state,
config,
seed,
params,
image_token_count
)
return image_tokens
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
print("detokenizing image") print("detokenizing image")
model_path = './pretrained/vqgan' image = self.detokenizer.forward(image_tokens).to(torch.uint8)
params = load_vqgan_torch_params(model_path) image = Image.fromarray(image.to('cpu').detach().numpy())
detokenizer = VQGanDetokenizer() return image
detokenizer.load_state_dict(params)
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8)
del detokenizer, params
return image.to('cpu').detach().numpy()

View File

@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
class DecoderSelfAttentionFlax(AttentionFlax): class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self, def __call__(
self,
decoder_state: jnp.ndarray, decoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
values_state: jnp.ndarray, values_state: jnp.ndarray,
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count) self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
decoder_state: jnp.ndarray, decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False) self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False) self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self, def __call__(
self,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
values_state: jnp.ndarray, values_state: jnp.ndarray,
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state) decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state return decoder_state, keys_state, values_state
def sample_image_tokens(self, def sample_image_tokens(
self,
text_tokens: jnp.ndarray, text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey, prng_key: jax.random.PRNGKey,

View File

@ -16,40 +16,34 @@ class DecoderCrossAttentionTorch(AttentionTorch):
keys = self.k_proj.forward(encoder_state) keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state) values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(decoder_state) queries = self.q_proj.forward(decoder_state)
query_shape = queries.shape[:2] + (self.head_count, -1)
key_value_shape = keys.shape[:2] + (self.head_count, -1)
keys = keys.reshape(key_value_shape)
values = values.reshape(key_value_shape)
queries = queries.reshape(query_shape)
queries /= queries.shape[-1] ** 0.5
return super().forward(keys, values, queries, attention_mask) return super().forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionTorch(AttentionTorch): class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self, def forward(
self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
keys_values: FloatTensor, keys_values: FloatTensor,
attention_mask: BoolTensor, attention_mask: BoolTensor,
token_mask: BoolTensor token_mask: BoolTensor
) -> Tuple[FloatTensor, FloatTensor]: ) -> Tuple[FloatTensor, FloatTensor]:
batch_count = decoder_state.shape[0] batch_count = decoder_state.shape[0]
shape = (batch_count, 1) + keys_values.shape[2:] keys = self.k_proj.forward(decoder_state)
keys = self.k_proj.forward(decoder_state).view(shape) values = self.v_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state).view(shape) queries = self.q_proj.forward(decoder_state)
keys_values = torch.where( keys_values = torch.where(
token_mask[None, :, None, None], token_mask[None, :, None],
torch.cat([keys, values]), torch.cat([keys, values]),
keys_values keys_values
) )
queries = self.q_proj.forward(decoder_state).reshape(shape)
queries /= queries.shape[-1] ** 0.5
keys, values = keys_values[:batch_count], keys_values[batch_count:] keys, values = keys_values[:batch_count], keys_values[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask) decoder_state = super().forward(keys, values, queries, attention_mask)
return decoder_state, keys_values return decoder_state, keys_values
class DecoderLayerTorch(nn.Module): class DecoderLayerTorch(nn.Module):
def __init__(self, def __init__(
self,
image_token_count: int, image_token_count: int,
head_count: int, head_count: int,
embed_count: int, embed_count: int,
@ -69,7 +63,8 @@ class DecoderLayerTorch(nn.Module):
if torch.cuda.is_available(): if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda() self.token_indices = self.token_indices.cuda()
def forward(self, def forward(
self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
encoder_state: FloatTensor, encoder_state: FloatTensor,
keys_values_state: FloatTensor, keys_values_state: FloatTensor,
@ -111,7 +106,8 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module): class DalleBartDecoderTorch(nn.Module):
def __init__(self, def __init__(
self,
image_vocab_size: int, image_vocab_size: int,
image_token_count: int, image_token_count: int,
sample_token_count: int, sample_token_count: int,
@ -146,8 +142,7 @@ class DalleBartDecoderTorch(nn.Module):
self.keys_values_state_shape = ( self.keys_values_state_shape = (
layer_count * 2 * batch_count, layer_count * 2 * batch_count,
image_token_count, image_token_count,
attention_head_count, embed_count
embed_count // attention_head_count
) )
self.zero_prob = torch.zeros([1]) self.zero_prob = torch.zeros([1])
self.token_indices = torch.arange(self.sample_token_count) self.token_indices = torch.arange(self.sample_token_count)
@ -158,7 +153,8 @@ class DalleBartDecoderTorch(nn.Module):
self.start_token = self.start_token.cuda() self.start_token = self.start_token.cuda()
def decode_step(self, def decode_step(
self,
text_tokens: LongTensor, text_tokens: LongTensor,
encoder_state: FloatTensor, encoder_state: FloatTensor,
keys_values_state: FloatTensor, keys_values_state: FloatTensor,
@ -183,7 +179,6 @@ class DalleBartDecoderTorch(nn.Module):
token_index[:1] token_index[:1]
) )
keys_values.append(keys_values_layer) keys_values.append(keys_values_layer)
keys_values = torch.cat(keys_values, dim=0)
decoder_state = self.final_ln(decoder_state) decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state) logits = self.lm_head(decoder_state)
a = self.condition_factor a = self.condition_factor
@ -195,10 +190,11 @@ class DalleBartDecoderTorch(nn.Module):
self.zero_prob, self.zero_prob,
torch.exp(logits - top_logits[0]) torch.exp(logits - top_logits[0])
) )
return probs, keys_values return probs, torch.cat(keys_values)
def forward(self, def forward(
self,
text_tokens: LongTensor, text_tokens: LongTensor,
encoder_state: FloatTensor encoder_state: FloatTensor
) -> LongTensor: ) -> LongTensor:

View File

@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False) self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False) self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self, def forward(
self,
keys: jnp.ndarray, keys: jnp.ndarray,
values: jnp.ndarray, values: jnp.ndarray,
queries: jnp.ndarray, queries: jnp.ndarray,
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count) self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray attention_mask: jnp.ndarray
) -> jnp.ndarray: ) -> jnp.ndarray:

View File

@ -37,12 +37,18 @@ class AttentionTorch(nn.Module):
self.one = torch.ones((1, 1)) self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda() if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self, def forward(
self,
keys: FloatTensor, keys: FloatTensor,
values: FloatTensor, values: FloatTensor,
queries: FloatTensor, queries: FloatTensor,
attention_mask: BoolTensor attention_mask: BoolTensor
) -> FloatTensor: ) -> FloatTensor:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias = torch.where( attention_bias = torch.where(
attention_mask, attention_mask,
self.one * 0, self.one * 0,
@ -72,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
encoder_state: FloatTensor, encoder_state: FloatTensor,
attention_mask: BoolTensor attention_mask: BoolTensor
) -> FloatTensor: ) -> FloatTensor:
shape_split = encoder_state.shape[:2] + (self.head_count, -1) keys = self.k_proj.forward(encoder_state)
keys = self.k_proj.forward(encoder_state).reshape(shape_split) values = self.v_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state).reshape(shape_split) queries = self.q_proj.forward(encoder_state)
queries = self.q_proj.forward(encoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
return super().forward(keys, values, queries, attention_mask) return super().forward(keys, values, queries, attention_mask)
@ -105,7 +109,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module): class DalleBartEncoderTorch(nn.Module):
def __init__(self, def __init__(
self,
layer_count: int, layer_count: int,
embed_count: int, embed_count: int,
attention_head_count: int, attention_head_count: int,

View File

@ -8,7 +8,7 @@ class TextTokenizer:
pairs = [tuple(pair.split()) for pair in merges] pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def __call__(self, text: str) -> List[int]: def tokenize(self, text: str) -> List[int]:
sep_token = self.token_from_subword['</s>'] sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>'] cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>'] unk_token = self.token_from_subword['<unk>']

2
requirements.txt vendored
View File

@ -1,3 +1,3 @@
torch torch
flax==0.4.2 flax==0.5.2
wandb wandb