Merge branch 'kuprel:main' into replicate

This commit is contained in:
Chenxi 2022-06-29 19:50:47 +01:00 committed by GitHub
commit fcc17c895d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 325 additions and 330 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
* linguist-vendored
*.py linguist-vendored=false

18
README.md vendored
View File

@ -3,27 +3,31 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
### Examples
```
python image_from_text.py --text='alien life' --seed=7
python image_from_text.py --text='artificial intelligence' --torch
```
![Alien](examples/alien.png)
![Alien](examples/artificial_intelligence.png)
```
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
```
![Avocado Armchair](examples/avocado_armchair.png)
@ -32,4 +36,4 @@ python image_from_text.py --text='a comfy chair that looks like an avocado' --me
python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100
```
![Godzilla Trial](examples/godzilla_trial.png)
![Godzilla Trial](examples/godzilla_trial.png)

BIN
examples/alien.png vendored

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

BIN
examples/artificial_intelligence.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 101 KiB

View File

@ -2,8 +2,8 @@ import argparse
import os
from PIL import Image
from min_dalle.generate_image import generate_image_from_text
from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
parser.add_argument('--torch', action='store_true')
parser.add_argument('--no-torch', dest='torch', action='store_false')
parser.set_defaults(torch=False)
parser.add_argument('--text', type=str)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--text', type=str, default='alien life')
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
def ascii_from_image(image: Image.Image, size: int) -> str:
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
return image
def generate_image(
is_torch: bool,
is_mega: bool,
text: str,
seed: int,
image_path: str,
sample_token_count: int
):
if is_torch:
image_generator = MinDalleTorch(is_mega, sample_token_count)
image_tokens = image_generator.generate_image_tokens(text, seed)
if sample_token_count < image_generator.config['image_length']:
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)
else:
image_generator = MinDalleFlax(is_mega)
image = image_generator.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__':
args = parser.parse_args()
print(args)
image = generate_image_from_text(
text = args.text,
is_mega = args.mega,
is_torch = args.torch,
seed = args.seed,
image_token_count = args.image_token_count
)
if image != None:
save_image(image, args.image_path)
print(ascii_from_image(image, size=128))
generate_image(
is_torch=args.torch,
is_mega=args.mega,
text=args.text,
seed=args.seed,
image_path=args.image_path,
sample_token_count=args.sample_token_count
)

92
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

View File

@ -1,78 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
import torch
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
generate_image_tokens_torch,
detokenize_torch
)
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("parsing metadata from {}".format(path))
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def tokenize_text(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
def generate_image_from_text(
text: str,
is_mega: bool = False,
is_torch: bool = False,
seed: int = 0,
image_token_count: int = 256
) -> Image.Image:
model_name = 'mega' if is_mega else 'mini'
model_path = './pretrained/dalle_bart_{}'.format(model_name)
config, vocab, merges = load_dalle_bart_metadata(model_path)
text_tokens = tokenize_text(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
if is_torch:
image_tokens = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
image_token_count = image_token_count
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens, is_torch=True)
return Image.fromarray(image)
else:
print(list(image_tokens.to('cpu').detach().numpy()))
else:
image_tokens = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
)
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
return Image.fromarray(image)

43
min_dalle/min_dalle.py Normal file
View File

@ -0,0 +1,43 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalle:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(model_path))
config_path = os.path.join(model_path, 'config.json')
vocab_path = os.path.join(model_path, 'vocab.json')
merges_path = os.path.join(model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.model_params = load_dalle_bart_flax_params(model_path)
self.tokenizer = TextTokenizer(vocab, merges)
self.detokenizer = VQGanDetokenizer()
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer.load_state_dict(vqgan_params)
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens

View File

@ -1,79 +1,58 @@
import jax
from jax import numpy as jnp
import numpy
from PIL import Image
import torch
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
def encode_flax(
text_tokens: numpy.ndarray,
config: dict,
params: dict
) -> jnp.ndarray:
print("loading flax encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['encoder_ffn_dim'],
text_token_count = config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'],
layer_count = config['encoder_layers']
).bind({'params': params.pop('encoder')})
class MinDalleFlax(MinDalle):
def __init__(self, is_mega: bool):
super().__init__(is_mega)
print("initializing MinDalleFlax")
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
print("loading encoder")
self.encoder = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = self.config['max_text_length'],
text_vocab_count = self.config['encoder_vocab_size'],
layer_count = self.config['encoder_layers']
).bind({'params': self.model_params.pop('encoder')})
print("loading decoder")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
)
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
)
print("sampling image tokens")
image_tokens = decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
params.pop('decoder')
)
del decoder
return image_tokens
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)
print("encoding text tokens")
encoder_state = self.encoder(text_tokens)
def generate_image_tokens_flax(
text_tokens: numpy.ndarray,
seed: int,
config: dict,
params: dict
) -> numpy.ndarray:
encoder_state = encode_flax(
text_tokens,
config,
params
)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config,
seed,
params
)
image_tokens = numpy.array(image_tokens)
print("image tokens", list(image_tokens))
return image_tokens
print("sampling image tokens")
image_tokens = self.decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
self.model_params['decoder']
)
image_tokens = torch.tensor(numpy.array(image_tokens))
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

View File

@ -1,118 +1,83 @@
from random import sample
import numpy
import os
from PIL import Image
from typing import Dict
from torch import LongTensor, FloatTensor
from torch import LongTensor
import torch
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .models.vqgan_detokenizer import VQGanDetokenizer
from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .load_params import (
load_vqgan_torch_params,
convert_dalle_bart_torch_from_flax_params
)
class MinDalleTorch(MinDalle):
def __init__(self, is_mega: bool, sample_token_count: int = 256):
super().__init__(is_mega)
print("initializing MinDalleTorch")
print("loading encoder")
self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'],
embed_count = self.config['d_model'],
attention_head_count = self.config['encoder_attention_heads'],
text_vocab_count = self.config['encoder_vocab_size'],
text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim']
)
encoder_params = convert_dalle_bart_torch_from_flax_params(
self.model_params.pop('encoder'),
layer_count=self.config['encoder_layers'],
is_encoder=True
)
self.encoder.load_state_dict(encoder_params, strict=False)
print("loading decoder")
self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = sample_token_count,
embed_count = self.config['d_model'],
attention_head_count = self.config['decoder_attention_heads'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
batch_count = 2,
start_token = self.config['decoder_start_token_id'],
is_verbose = True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
self.model_params.pop('decoder'),
layer_count=self.config['decoder_layers'],
is_encoder=False
)
self.decoder.load_state_dict(decoder_params, strict=False)
if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
self.decoder = self.decoder.cuda()
self.detokenizer = self.detokenizer.cuda()
def encode_torch(
text_tokens: LongTensor,
config: dict,
params: dict
) -> FloatTensor:
print("loading torch encoder")
encoder = DalleBartEncoderTorch(
layer_count = config['encoder_layers'],
embed_count = config['d_model'],
attention_head_count = config['encoder_attention_heads'],
text_vocab_count = config['encoder_vocab_size'],
text_token_count = config['max_text_length'],
glu_embed_count = config['encoder_ffn_dim']
)
encoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('encoder'),
layer_count=config['encoder_layers'],
is_encoder=True
)
encoder.load_state_dict(encoder_params, strict=False)
del encoder_params
if torch.cuda.is_available(): encoder = encoder.cuda()
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
print("sampling image tokens")
torch.manual_seed(seed)
image_tokens = self.decoder.forward(text_tokens, encoder_state)
return image_tokens
def decode_torch(
text_tokens: LongTensor,
encoder_state: FloatTensor,
config: dict,
seed: int,
params: dict,
image_token_count: int
) -> LongTensor:
print("loading torch decoder")
decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'],
image_token_count = config['image_length'],
sample_token_count = image_token_count,
embed_count = config['d_model'],
attention_head_count = config['decoder_attention_heads'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
batch_count = 2,
start_token = config['decoder_start_token_id'],
is_verbose = True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('decoder'),
layer_count=config['decoder_layers'],
is_encoder=False
)
decoder.load_state_dict(decoder_params, strict=False)
del decoder_params
if torch.cuda.is_available(): decoder = decoder.cuda()
print("sampling image tokens")
torch.manual_seed(seed)
image_tokens = decoder.forward(text_tokens, encoder_state)
return image_tokens
def generate_image_tokens_torch(
text_tokens: numpy.ndarray,
seed: int,
config: dict,
params: dict,
image_token_count: int
) -> LongTensor:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,
params
)
image_tokens = decode_torch(
text_tokens,
encoder_state,
config,
seed,
params,
image_token_count
)
return image_tokens
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
print("detokenizing image")
model_path = './pretrained/vqgan'
params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params)
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8)
del detokenizer, params
return image.to('cpu').detach().numpy()
def generate_image(self, text: str, seed: int) -> Image.Image:
image_tokens = self.generate_image_tokens(text, seed)
print("detokenizing image")
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
image = Image.fromarray(image.to('cpu').detach().numpy())
return image

View File

@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
def sample_image_tokens(self,
def sample_image_tokens(
self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey,

View File

@ -16,40 +16,34 @@ class DecoderCrossAttentionTorch(AttentionTorch):
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(decoder_state)
query_shape = queries.shape[:2] + (self.head_count, -1)
key_value_shape = keys.shape[:2] + (self.head_count, -1)
keys = keys.reshape(key_value_shape)
values = values.reshape(key_value_shape)
queries = queries.reshape(query_shape)
queries /= queries.shape[-1] ** 0.5
return super().forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
keys_values: FloatTensor,
attention_mask: BoolTensor,
token_mask: BoolTensor
) -> Tuple[FloatTensor, FloatTensor]:
batch_count = decoder_state.shape[0]
shape = (batch_count, 1) + keys_values.shape[2:]
keys = self.k_proj.forward(decoder_state).view(shape)
values = self.v_proj.forward(decoder_state).view(shape)
keys = self.k_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state)
queries = self.q_proj.forward(decoder_state)
keys_values = torch.where(
token_mask[None, :, None, None],
token_mask[None, :, None],
torch.cat([keys, values]),
keys_values
)
queries = self.q_proj.forward(decoder_state).reshape(shape)
queries /= queries.shape[-1] ** 0.5
keys, values = keys_values[:batch_count], keys_values[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask)
return decoder_state, keys_values
class DecoderLayerTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_token_count: int,
head_count: int,
embed_count: int,
@ -69,7 +63,8 @@ class DecoderLayerTorch(nn.Module):
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@ -111,7 +106,8 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_vocab_size: int,
image_token_count: int,
sample_token_count: int,
@ -146,8 +142,7 @@ class DalleBartDecoderTorch(nn.Module):
self.keys_values_state_shape = (
layer_count * 2 * batch_count,
image_token_count,
attention_head_count,
embed_count // attention_head_count
embed_count
)
self.zero_prob = torch.zeros([1])
self.token_indices = torch.arange(self.sample_token_count)
@ -158,7 +153,8 @@ class DalleBartDecoderTorch(nn.Module):
self.start_token = self.start_token.cuda()
def decode_step(self,
def decode_step(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@ -183,7 +179,6 @@ class DalleBartDecoderTorch(nn.Module):
token_index[:1]
)
keys_values.append(keys_values_layer)
keys_values = torch.cat(keys_values, dim=0)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = self.condition_factor
@ -195,10 +190,11 @@ class DalleBartDecoderTorch(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[0])
)
return probs, keys_values
return probs, torch.cat(keys_values)
def forward(self,
def forward(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> LongTensor:

View File

@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self,
def forward(
self,
keys: jnp.ndarray,
values: jnp.ndarray,
queries: jnp.ndarray,
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:

View File

@ -37,12 +37,18 @@ class AttentionTorch(nn.Module):
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self,
def forward(
self,
keys: FloatTensor,
values: FloatTensor,
queries: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
values = values.reshape(values.shape[:2] + (self.head_count, -1))
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias = torch.where(
attention_mask,
self.one * 0,
@ -72,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
encoder_state: FloatTensor,
attention_mask: BoolTensor
) -> FloatTensor:
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
keys = self.k_proj.forward(encoder_state).reshape(shape_split)
values = self.v_proj.forward(encoder_state).reshape(shape_split)
queries = self.q_proj.forward(encoder_state).reshape(shape_split)
queries /= queries.shape[-1] ** 0.5
keys = self.k_proj.forward(encoder_state)
values = self.v_proj.forward(encoder_state)
queries = self.q_proj.forward(encoder_state)
return super().forward(keys, values, queries, attention_mask)
@ -105,7 +109,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
layer_count: int,
embed_count: int,
attention_head_count: int,

View File

@ -8,7 +8,7 @@ class TextTokenizer:
pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def __call__(self, text: str) -> List[int]:
def tokenize(self, text: str) -> List[int]:
sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>']

2
requirements.txt vendored
View File

@ -1,3 +1,3 @@
torch
flax==0.4.2
flax==0.5.2
wandb