Merge branch 'main' into patch-1
This commit is contained in:
commit
e3329a7f64
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
* linguist-vendored
|
||||||
|
*.py linguist-vendored=false
|
19
README.md
vendored
19
README.md
vendored
|
@ -1,28 +1,33 @@
|
||||||
# min(DALL·E)
|
# min(DALL·E)
|
||||||
|
|
||||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||||
|
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
||||||
|
|
||||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
|
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which are saved with `torch.save` the first time the model is loaded), and wandb is only used to download the models.
|
||||||
|
|
||||||
|
It currently takes **7.4 seconds** to generate an image with DALL·E Mega with PyTorch on a standard GPU runtime in Colab
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. In the bash script, Git LFS is used to download the VQGan detokenizer from Hugging Face and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. These models can also be downloaded manually:
|
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
|
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='alien life' --seed=7
|
python image_from_text.py --text='artificial intelligence' --torch
|
||||||
```
|
```
|
||||||
![Alien](examples/alien.png)
|
![Alien](examples/artificial_intelligence.png)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
|
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
|
||||||
```
|
```
|
||||||
![Avocado Armchair](examples/avocado_armchair.png)
|
![Avocado Armchair](examples/avocado_armchair.png)
|
||||||
|
|
||||||
|
|
12
cog.yaml
vendored
Normal file
12
cog.yaml
vendored
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
build:
|
||||||
|
cuda: "11.0"
|
||||||
|
gpu: true
|
||||||
|
python_version: "3.8"
|
||||||
|
system_packages:
|
||||||
|
- "libgl1-mesa-glx"
|
||||||
|
- "libglib2.0-0"
|
||||||
|
python_packages:
|
||||||
|
- "torch==1.10.1"
|
||||||
|
- "flax==0.5.2"
|
||||||
|
|
||||||
|
predict: "predict.py:Predictor"
|
BIN
examples/alien.png
vendored
BIN
examples/alien.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 55 KiB |
BIN
examples/artificial_intelligence.png
vendored
Normal file
BIN
examples/artificial_intelligence.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 127 KiB |
BIN
examples/avocado_armchair.png
vendored
BIN
examples/avocado_armchair.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 90 KiB After Width: | Height: | Size: 101 KiB |
|
@ -2,8 +2,8 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from min_dalle.generate_image import generate_image_from_text
|
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||||
|
from min_dalle.min_dalle_flax import MinDalleFlax
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--mega', action='store_true')
|
parser.add_argument('--mega', action='store_true')
|
||||||
|
@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--torch', action='store_true')
|
parser.add_argument('--torch', action='store_true')
|
||||||
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
||||||
parser.set_defaults(torch=False)
|
parser.set_defaults(torch=False)
|
||||||
parser.add_argument('--text', type=str)
|
parser.add_argument('--text', type=str, default='alien life')
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=7)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
|
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
|
||||||
def ascii_from_image(image: Image.Image, size: int) -> str:
|
def ascii_from_image(image: Image.Image, size: int) -> str:
|
||||||
|
@ -36,19 +36,41 @@ def save_image(image: Image.Image, path: str):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(
|
||||||
|
is_torch: bool,
|
||||||
|
is_mega: bool,
|
||||||
|
text: str,
|
||||||
|
seed: int,
|
||||||
|
image_path: str,
|
||||||
|
token_count: int
|
||||||
|
):
|
||||||
|
is_reusable = False
|
||||||
|
if is_torch:
|
||||||
|
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
|
||||||
|
|
||||||
|
if token_count < image_generator.config['image_length']:
|
||||||
|
image_tokens = image_generator.generate_image_tokens(text, seed)
|
||||||
|
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
else:
|
||||||
|
image_generator = MinDalleFlax(is_mega, is_reusable)
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
save_image(image, image_path)
|
||||||
|
print(ascii_from_image(image, size=128))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
|
generate_image(
|
||||||
image = generate_image_from_text(
|
is_torch=args.torch,
|
||||||
text = args.text,
|
is_mega=args.mega,
|
||||||
is_mega = args.mega,
|
text=args.text,
|
||||||
is_torch = args.torch,
|
seed=args.seed,
|
||||||
seed = args.seed,
|
image_path=args.image_path,
|
||||||
image_token_count = args.image_token_count
|
token_count=args.token_count
|
||||||
)
|
)
|
||||||
|
|
||||||
if image != None:
|
|
||||||
save_image(image, args.image_path)
|
|
||||||
print(ascii_from_image(image, size=128))
|
|
130
min_dalle.ipynb
vendored
130
min_dalle.ipynb
vendored
File diff suppressed because one or more lines are too long
|
@ -1,77 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy
|
|
||||||
from PIL import Image
|
|
||||||
from typing import Tuple, List
|
|
||||||
|
|
||||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
|
||||||
from min_dalle.text_tokenizer import TextTokenizer
|
|
||||||
from min_dalle.min_dalle_flax import generate_image_tokens_flax
|
|
||||||
from min_dalle.min_dalle_torch import (
|
|
||||||
generate_image_tokens_torch,
|
|
||||||
detokenize_torch
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
|
|
||||||
print("parsing metadata from {}".format(path))
|
|
||||||
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
|
|
||||||
assert(os.path.exists(os.path.join(path, f)))
|
|
||||||
with open(path + '/config.json', 'r') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
with open(path + '/vocab.json') as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
with open(path + '/merges.txt') as f:
|
|
||||||
merges = f.read().split("\n")[1:-1]
|
|
||||||
return config, vocab, merges
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_text(
|
|
||||||
text: str,
|
|
||||||
config: dict,
|
|
||||||
vocab: dict,
|
|
||||||
merges: List[str]
|
|
||||||
) -> numpy.ndarray:
|
|
||||||
print("tokenizing text")
|
|
||||||
tokens = TextTokenizer(vocab, merges)(text)
|
|
||||||
print("text tokens", tokens)
|
|
||||||
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
|
|
||||||
text_tokens[0, :len(tokens)] = tokens
|
|
||||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
|
||||||
return text_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_from_text(
|
|
||||||
text: str,
|
|
||||||
is_mega: bool = False,
|
|
||||||
is_torch: bool = False,
|
|
||||||
seed: int = 0,
|
|
||||||
image_token_count: int = 256
|
|
||||||
) -> Image.Image:
|
|
||||||
model_name = 'mega' if is_mega else 'mini'
|
|
||||||
model_path = './pretrained/dalle_bart_{}'.format(model_name)
|
|
||||||
config, vocab, merges = load_dalle_bart_metadata(model_path)
|
|
||||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
|
||||||
params_dalle_bart = load_dalle_bart_flax_params(model_path)
|
|
||||||
|
|
||||||
image_tokens = numpy.zeros(config['image_length'])
|
|
||||||
if is_torch:
|
|
||||||
image_tokens[:image_token_count] = generate_image_tokens_torch(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
image_token_count = image_token_count
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image_tokens[...] = generate_image_tokens_flax(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
)
|
|
||||||
|
|
||||||
if image_token_count == config['image_length']:
|
|
||||||
image = detokenize_torch(image_tokens)
|
|
||||||
return Image.fromarray(image)
|
|
||||||
else:
|
|
||||||
return None
|
|
|
@ -1,17 +1,17 @@
|
||||||
import os
|
import os
|
||||||
import numpy
|
import numpy
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from flax import traverse_util, serialization
|
from flax.traverse_util import flatten_dict
|
||||||
|
from flax.serialization import msgpack_restore
|
||||||
import torch
|
import torch
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
|
||||||
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||||
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
|
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
|
||||||
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
|
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
|
||||||
|
|
||||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
|
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
||||||
|
|
||||||
for i in list(P.keys()):
|
for i in list(P.keys()):
|
||||||
j = i
|
j = i
|
||||||
|
@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
|
||||||
for i in P:
|
for i in P:
|
||||||
P[i] = torch.tensor(P[i])
|
P[i] = torch.tensor(P[i])
|
||||||
if torch.cuda.is_available(): P[i] = P[i].cuda()
|
|
||||||
|
|
||||||
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
|
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
|
||||||
|
|
||||||
|
@ -43,7 +42,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
|
||||||
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
|
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
|
||||||
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
|
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
|
||||||
params = serialization.msgpack_restore(f.read())
|
params = msgpack_restore(f.read())
|
||||||
|
|
||||||
for codec in ['encoder', 'decoder']:
|
for codec in ['encoder', 'decoder']:
|
||||||
k = 'FlaxBart{}Layers'.format(codec.title())
|
k = 'FlaxBart{}Layers'.format(codec.title())
|
||||||
|
@ -82,12 +81,10 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
is_encoder: bool
|
is_encoder: bool
|
||||||
) -> dict:
|
) -> dict:
|
||||||
P = deepcopy(params)
|
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
||||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
|
|
||||||
|
|
||||||
for i in P:
|
for i in P:
|
||||||
P[i] = torch.tensor(P[i])
|
P[i] = torch.tensor(P[i]).to(torch.float16)
|
||||||
if torch.cuda.is_available(): P[i] = P[i].cuda()
|
|
||||||
|
|
||||||
for i in list(P):
|
for i in list(P):
|
||||||
if 'kernel' in i:
|
if 'kernel' in i:
|
||||||
|
@ -108,3 +105,28 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||||
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
||||||
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
||||||
return P
|
return P
|
||||||
|
|
||||||
|
|
||||||
|
def convert_and_save_torch_params(is_mega: bool, model_path: str):
|
||||||
|
print("converting params to torch")
|
||||||
|
layer_count = 24 if is_mega else 12
|
||||||
|
flax_params = load_dalle_bart_flax_params(model_path)
|
||||||
|
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
|
flax_params['encoder'],
|
||||||
|
layer_count=layer_count,
|
||||||
|
is_encoder=True
|
||||||
|
)
|
||||||
|
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
|
flax_params['decoder'],
|
||||||
|
layer_count=layer_count,
|
||||||
|
is_encoder=False
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in decoder_params:
|
||||||
|
decoder_params[i] = decoder_params[i].to(torch.float16)
|
||||||
|
|
||||||
|
for i in encoder_params:
|
||||||
|
encoder_params[i] = encoder_params[i].to(torch.float16)
|
||||||
|
|
||||||
|
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
|
||||||
|
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))
|
46
min_dalle/min_dalle_base.py
Normal file
46
min_dalle/min_dalle_base.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from .text_tokenizer import TextTokenizer
|
||||||
|
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
|
||||||
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
|
class MinDalleBase:
|
||||||
|
def __init__(self, is_mega: bool):
|
||||||
|
self.is_mega = is_mega
|
||||||
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||||
|
self.model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
|
print("reading files from {}".format(self.model_path))
|
||||||
|
config_path = os.path.join(self.model_path, 'config.json')
|
||||||
|
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
||||||
|
merges_path = os.path.join(self.model_path, 'merges.txt')
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf8') as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
with open(vocab_path, 'r', encoding='utf8') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
with open(merges_path, 'r', encoding='utf8') as f:
|
||||||
|
merges = f.read().split("\n")[1:-1]
|
||||||
|
|
||||||
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
|
|
||||||
|
|
||||||
|
def init_detokenizer(self):
|
||||||
|
print("initializing VQGanDetokenizer")
|
||||||
|
params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||||
|
self.detokenizer = VQGanDetokenizer()
|
||||||
|
self.detokenizer.load_state_dict(params)
|
||||||
|
del params
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||||
|
print("tokenizing text")
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
print("text tokens", tokens)
|
||||||
|
text_token_count = self.config['max_text_length']
|
||||||
|
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
||||||
|
text_tokens[0, :len(tokens)] = tokens
|
||||||
|
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
||||||
|
return text_tokens
|
|
@ -1,79 +1,80 @@
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||||
|
|
||||||
|
from .load_params import load_dalle_bart_flax_params
|
||||||
def encode_flax(
|
|
||||||
text_tokens: numpy.ndarray,
|
|
||||||
config: dict,
|
|
||||||
params: dict
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
print("loading flax encoder")
|
|
||||||
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
glu_embed_count = config['encoder_ffn_dim'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
|
||||||
layer_count = config['encoder_layers']
|
|
||||||
).bind({'params': params.pop('encoder')})
|
|
||||||
|
|
||||||
print("encoding text tokens")
|
|
||||||
encoder_state = encoder(text_tokens)
|
|
||||||
del encoder
|
|
||||||
return encoder_state
|
|
||||||
|
|
||||||
|
|
||||||
def decode_flax(
|
class MinDalleFlax(MinDalleBase):
|
||||||
text_tokens: jnp.ndarray,
|
def __init__(self, is_mega: bool, is_reusable: bool = True):
|
||||||
encoder_state: jnp.ndarray,
|
super().__init__(is_mega)
|
||||||
config: dict,
|
self.is_reusable = is_reusable
|
||||||
seed: int,
|
print("initializing MinDalleFlax")
|
||||||
params: dict
|
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
||||||
) -> jnp.ndarray:
|
if is_reusable:
|
||||||
print("loading flax decoder")
|
self.init_encoder()
|
||||||
decoder = DalleBartDecoderFlax(
|
self.init_decoder()
|
||||||
image_token_count = config['image_length'],
|
self.init_detokenizer()
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
image_vocab_count = config['image_vocab_size'],
|
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
|
||||||
layer_count = config['decoder_layers'],
|
|
||||||
start_token = config['decoder_start_token_id']
|
|
||||||
)
|
|
||||||
print("sampling image tokens")
|
|
||||||
image_tokens = decoder.sample_image_tokens(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
jax.random.PRNGKey(seed),
|
|
||||||
params.pop('decoder')
|
|
||||||
)
|
|
||||||
del decoder
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens_flax(
|
def init_encoder(self):
|
||||||
text_tokens: numpy.ndarray,
|
print("initializing DalleBartEncoderFlax")
|
||||||
seed: int,
|
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
||||||
config: dict,
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
params: dict
|
embed_count = self.config['d_model'],
|
||||||
) -> numpy.ndarray:
|
glu_embed_count = self.config['encoder_ffn_dim'],
|
||||||
encoder_state = encode_flax(
|
text_token_count = self.config['max_text_length'],
|
||||||
text_tokens,
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
config,
|
layer_count = self.config['encoder_layers']
|
||||||
params
|
).bind({'params': self.model_params.pop('encoder')})
|
||||||
)
|
|
||||||
image_tokens = decode_flax(
|
|
||||||
text_tokens,
|
def init_decoder(self):
|
||||||
encoder_state,
|
print("initializing DalleBartDecoderFlax")
|
||||||
config,
|
self.decoder = DalleBartDecoderFlax(
|
||||||
seed,
|
image_token_count = self.config['image_length'],
|
||||||
params
|
text_token_count = self.config['max_text_length'],
|
||||||
)
|
image_vocab_count = self.config['image_vocab_size'],
|
||||||
image_tokens = numpy.array(image_tokens)
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
print("image tokens", list(image_tokens))
|
embed_count = self.config['d_model'],
|
||||||
return image_tokens
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
|
layer_count = self.config['decoder_layers'],
|
||||||
|
start_token = self.config['decoder_start_token_id']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
|
text_tokens = self.tokenize_text(text)
|
||||||
|
|
||||||
|
if not self.is_reusable: self.init_encoder()
|
||||||
|
print("encoding text tokens")
|
||||||
|
encoder_state = self.encoder(text_tokens)
|
||||||
|
if not self.is_reusable: del self.encoder
|
||||||
|
|
||||||
|
if not self.is_reusable:
|
||||||
|
self.init_decoder()
|
||||||
|
params = self.model_params.pop('decoder')
|
||||||
|
else:
|
||||||
|
params = self.model_params['decoder']
|
||||||
|
print("sampling image tokens")
|
||||||
|
image_tokens = self.decoder.sample_image_tokens(
|
||||||
|
text_tokens,
|
||||||
|
encoder_state,
|
||||||
|
jax.random.PRNGKey(seed),
|
||||||
|
params
|
||||||
|
)
|
||||||
|
if not self.is_reusable: del self.decoder
|
||||||
|
|
||||||
|
image_tokens = torch.tensor(numpy.array(image_tokens))
|
||||||
|
|
||||||
|
if not self.is_reusable: self.init_detokenizer()
|
||||||
|
print("detokenizing image")
|
||||||
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
|
if not self.is_reusable: del self.detokenizer
|
||||||
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
|
return image
|
|
@ -1,113 +1,114 @@
|
||||||
import numpy
|
import os
|
||||||
|
from PIL import Image
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from torch import LongTensor, FloatTensor
|
from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .load_params import (
|
||||||
|
convert_and_save_torch_params,
|
||||||
|
load_dalle_bart_flax_params
|
||||||
|
)
|
||||||
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
|
||||||
from .load_params import (
|
|
||||||
load_vqgan_torch_params,
|
class MinDalleTorch(MinDalleBase):
|
||||||
convert_dalle_bart_torch_from_flax_params
|
def __init__(
|
||||||
)
|
self,
|
||||||
|
is_mega: bool,
|
||||||
|
is_reusable: bool = True,
|
||||||
|
token_count: int = 256
|
||||||
|
):
|
||||||
|
print("initializing MinDalleTorch")
|
||||||
|
super().__init__(is_mega)
|
||||||
|
self.is_reusable = is_reusable
|
||||||
|
self.token_count = token_count
|
||||||
|
|
||||||
|
if not is_mega:
|
||||||
|
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
||||||
|
|
||||||
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
||||||
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||||
|
|
||||||
|
is_converted = os.path.exists(self.encoder_params_path)
|
||||||
|
is_converted &= os.path.exists(self.decoder_params_path)
|
||||||
|
if not is_converted:
|
||||||
|
convert_and_save_torch_params(is_mega, self.model_path)
|
||||||
|
|
||||||
|
if is_reusable:
|
||||||
|
self.init_encoder()
|
||||||
|
self.init_decoder()
|
||||||
|
self.init_detokenizer()
|
||||||
|
|
||||||
|
|
||||||
def encode_torch(
|
def init_encoder(self):
|
||||||
text_tokens: LongTensor,
|
print("initializing DalleBartEncoderTorch")
|
||||||
config: dict,
|
self.encoder = DalleBartEncoderTorch(
|
||||||
params: dict
|
layer_count = self.config['encoder_layers'],
|
||||||
) -> FloatTensor:
|
embed_count = self.config['d_model'],
|
||||||
print("loading torch encoder")
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
encoder = DalleBartEncoderTorch(
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
layer_count = config['encoder_layers'],
|
text_token_count = self.config['max_text_length'],
|
||||||
embed_count = config['d_model'],
|
glu_embed_count = self.config['encoder_ffn_dim']
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
)
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
params = torch.load(self.encoder_params_path)
|
||||||
text_token_count = config['max_text_length'],
|
self.encoder.load_state_dict(params, strict=False)
|
||||||
glu_embed_count = config['encoder_ffn_dim']
|
del params
|
||||||
)
|
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
|
||||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
params.pop('encoder'),
|
|
||||||
layer_count=config['encoder_layers'],
|
|
||||||
is_encoder=True
|
|
||||||
)
|
|
||||||
encoder.load_state_dict(encoder_params, strict=False)
|
|
||||||
del encoder_params
|
|
||||||
|
|
||||||
print("encoding text tokens")
|
|
||||||
encoder_state = encoder(text_tokens)
|
|
||||||
del encoder
|
|
||||||
return encoder_state
|
|
||||||
|
|
||||||
|
|
||||||
def decode_torch(
|
def init_decoder(self):
|
||||||
text_tokens: LongTensor,
|
print("initializing DalleBartDecoderTorch")
|
||||||
encoder_state: FloatTensor,
|
self.decoder = DalleBartDecoderTorch(
|
||||||
config: dict,
|
image_vocab_size = self.config['image_vocab_size'],
|
||||||
seed: int,
|
image_token_count = self.config['image_length'],
|
||||||
params: dict,
|
sample_token_count = self.token_count,
|
||||||
image_token_count: int
|
embed_count = self.config['d_model'],
|
||||||
) -> LongTensor:
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
print("loading torch decoder")
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
decoder = DalleBartDecoderTorch(
|
layer_count = self.config['decoder_layers'],
|
||||||
image_vocab_size = config['image_vocab_size'],
|
batch_count = 2,
|
||||||
image_token_count = config['image_length'],
|
start_token = self.config['decoder_start_token_id'],
|
||||||
sample_token_count = image_token_count,
|
is_verbose = True
|
||||||
embed_count = config['d_model'],
|
)
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
params = torch.load(self.decoder_params_path)
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
self.decoder.load_state_dict(params, strict=False)
|
||||||
layer_count = config['decoder_layers'],
|
del params
|
||||||
batch_count = 2,
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
||||||
start_token = config['decoder_start_token_id'],
|
|
||||||
is_verbose = True
|
|
||||||
)
|
|
||||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
params.pop('decoder'),
|
|
||||||
layer_count=config['decoder_layers'],
|
|
||||||
is_encoder=False
|
|
||||||
)
|
|
||||||
decoder.load_state_dict(decoder_params, strict=False)
|
|
||||||
del decoder_params
|
|
||||||
|
|
||||||
print("sampling image tokens")
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
image_tokens = decoder.forward(text_tokens, encoder_state)
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens_torch(
|
def init_detokenizer(self):
|
||||||
text_tokens: numpy.ndarray,
|
super().init_detokenizer()
|
||||||
seed: int,
|
if torch.cuda.is_available():
|
||||||
config: dict,
|
self.detokenizer = self.detokenizer.cuda()
|
||||||
params: dict,
|
|
||||||
image_token_count: int
|
|
||||||
) -> LongTensor:
|
|
||||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
|
||||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
|
||||||
encoder_state = encode_torch(
|
|
||||||
text_tokens,
|
|
||||||
config,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = decode_torch(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
config,
|
|
||||||
seed,
|
|
||||||
params,
|
|
||||||
image_token_count
|
|
||||||
)
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||||
print("detokenizing image")
|
text_tokens = self.tokenize_text(text)
|
||||||
model_path = './pretrained/vqgan'
|
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||||
params = load_vqgan_torch_params(model_path)
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||||
detokenizer = VQGanDetokenizer()
|
|
||||||
detokenizer.load_state_dict(params)
|
|
||||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
|
||||||
return image.detach().numpy()
|
|
||||||
|
|
||||||
|
if not self.is_reusable: self.init_encoder()
|
||||||
|
print("encoding text tokens")
|
||||||
|
encoder_state = self.encoder.forward(text_tokens)
|
||||||
|
if not self.is_reusable: del self.encoder
|
||||||
|
|
||||||
|
if not self.is_reusable: self.init_decoder()
|
||||||
|
print("sampling image tokens")
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image_tokens = self.decoder.forward(text_tokens, encoder_state)
|
||||||
|
if not self.is_reusable: del self.decoder
|
||||||
|
return image_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
|
image_tokens = self.generate_image_tokens(text, seed)
|
||||||
|
if not self.is_reusable: self.init_detokenizer()
|
||||||
|
print("detokenizing image")
|
||||||
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
|
if not self.is_reusable: del self.detokenizer
|
||||||
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
|
return image
|
|
@ -13,46 +13,39 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
keys: jnp.ndarray = self.k_proj(encoder_state)
|
keys = self.k_proj(encoder_state)
|
||||||
values: jnp.ndarray = self.v_proj(encoder_state)
|
values = self.v_proj(encoder_state)
|
||||||
queries: jnp.ndarray = self.q_proj(decoder_state)
|
queries = self.q_proj(decoder_state)
|
||||||
query_shape = queries.shape[:2] + (self.head_count, -1)
|
|
||||||
key_value_shape = keys.shape[:2] + (self.head_count, -1)
|
|
||||||
keys = keys.reshape(key_value_shape)
|
|
||||||
values = values.reshape(key_value_shape)
|
|
||||||
queries = queries.reshape(query_shape)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
return self.forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
state_index: tuple
|
state_index: tuple
|
||||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj(decoder_state)
|
||||||
keys_state = lax.dynamic_update_slice(
|
values = self.v_proj(decoder_state)
|
||||||
keys_state,
|
queries = self.q_proj(decoder_state)
|
||||||
self.k_proj(decoder_state).reshape(shape_split),
|
|
||||||
|
attention_state = lax.dynamic_update_slice(
|
||||||
|
attention_state,
|
||||||
|
jnp.concatenate([keys, values]),
|
||||||
state_index
|
state_index
|
||||||
)
|
)
|
||||||
values_state = lax.dynamic_update_slice(
|
batch_count = decoder_state.shape[0]
|
||||||
values_state,
|
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
||||||
self.v_proj(decoder_state).reshape(shape_split),
|
|
||||||
state_index
|
|
||||||
)
|
|
||||||
queries = self.q_proj(decoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
decoder_state = self.forward(
|
decoder_state = self.forward(
|
||||||
keys_state,
|
keys,
|
||||||
values_state,
|
values,
|
||||||
queries,
|
queries,
|
||||||
attention_mask
|
attention_mask
|
||||||
)
|
)
|
||||||
return decoder_state, (keys_state, values_state)
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderLayerFlax(nn.Module):
|
class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
|
@ -77,14 +70,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
residual = decoder_state
|
residual = decoder_state
|
||||||
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
||||||
|
@ -92,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
jnp.arange(self.image_token_count) < token_index + 1,
|
jnp.arange(self.image_token_count) < token_index + 1,
|
||||||
(decoder_state.shape[0], 1)
|
(decoder_state.shape[0], 1)
|
||||||
)
|
)
|
||||||
decoder_state, keys_values_state = self.self_attn(
|
decoder_state, attention_state = self.self_attn(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
keys_state,
|
attention_state,
|
||||||
values_state,
|
|
||||||
self_attention_mask,
|
self_attention_mask,
|
||||||
(0, token_index, 0, 0)
|
(0, token_index, 0)
|
||||||
)
|
)
|
||||||
decoder_state = self.self_attn_layer_norm(decoder_state)
|
decoder_state = self.self_attn_layer_norm(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
@ -118,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
decoder_state = self.glu(decoder_state)
|
decoder_state = self.glu(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
|
||||||
return decoder_state, keys_values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class SampleState:
|
class SampleState:
|
||||||
prev_token: jnp.ndarray
|
prev_token: jnp.ndarray
|
||||||
prng_key: jnp.ndarray
|
prng_key: jnp.ndarray
|
||||||
keys_state: jnp.ndarray
|
attention_state: jnp.ndarray
|
||||||
values_state: jnp.ndarray
|
|
||||||
|
|
||||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||||
|
@ -157,10 +148,10 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
)
|
)
|
||||||
self.layers = nn.scan(
|
self.layers = nn.scan(
|
||||||
DalleBartDecoderLayerFlax,
|
DalleBartDecoderLayerFlax,
|
||||||
variable_axes = { "params": 0, "cache": 0 },
|
variable_axes = { "params": 0 },
|
||||||
split_rngs = { "params": True },
|
split_rngs = { "params": True },
|
||||||
in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
|
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
||||||
out_axes = (0, 0),
|
out_axes = 0,
|
||||||
length=self.layer_count,
|
length=self.layer_count,
|
||||||
)(
|
)(
|
||||||
self.image_token_count,
|
self.image_token_count,
|
||||||
|
@ -173,32 +164,32 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
self.final_ln = nn.LayerNorm(use_scale=False)
|
||||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
||||||
|
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
prev_token: int,
|
prev_token: int,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
batch_count = encoder_state.shape[0]
|
batch_count = encoder_state.shape[0]
|
||||||
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
||||||
decoder_state = self.embed_tokens(prev_token * ones)
|
decoder_state = self.embed_tokens(prev_token * ones)
|
||||||
decoder_state += self.embed_positions(token_index * ones)
|
decoder_state += self.embed_positions(token_index * ones)
|
||||||
decoder_state = self.layernorm_embedding(decoder_state)
|
decoder_state = self.layernorm_embedding(decoder_state)
|
||||||
decoder_state, (keys_state, values_state) = self.layers(
|
decoder_state, attention_state = self.layers(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
keys_state,
|
attention_state,
|
||||||
values_state,
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_index
|
token_index
|
||||||
)
|
)
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
decoder_state = self.lm_head(decoder_state)
|
decoder_state = self.lm_head(decoder_state)
|
||||||
return decoder_state, keys_state, values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
def sample_image_tokens(self,
|
def sample_image_tokens(
|
||||||
|
self,
|
||||||
text_tokens: jnp.ndarray,
|
text_tokens: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
prng_key: jax.random.PRNGKey,
|
prng_key: jax.random.PRNGKey,
|
||||||
|
@ -209,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
def sample_next_image_token(
|
def sample_next_image_token(
|
||||||
state: SampleState,
|
state: SampleState,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[SampleState, None]:
|
) -> Tuple[SampleState, jnp.ndarray]:
|
||||||
logits, keys_state, values_state = self.apply(
|
logits, attention_state = self.apply(
|
||||||
{ 'params': params },
|
{ 'params': params },
|
||||||
encoder_state = encoder_state,
|
encoder_state = encoder_state,
|
||||||
keys_state = state.keys_state,
|
attention_state = state.attention_state,
|
||||||
values_state = state.values_state,
|
|
||||||
attention_mask = attention_mask,
|
attention_mask = attention_mask,
|
||||||
prev_token = state.prev_token,
|
prev_token = state.prev_token,
|
||||||
token_index = token_index
|
token_index = token_index
|
||||||
|
@ -229,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
state = SampleState(
|
state = SampleState(
|
||||||
prev_token = next_token,
|
prev_token = next_token,
|
||||||
prng_key = prng_key_next,
|
prng_key = prng_key_next,
|
||||||
keys_state = keys_state,
|
attention_state = attention_state
|
||||||
values_state = values_state
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return state, next_token
|
return state, next_token
|
||||||
|
|
||||||
batch_count = encoder_state.shape[0]
|
batch_count = encoder_state.shape[0]
|
||||||
state_shape = (
|
attention_state_shape = (
|
||||||
self.layer_count,
|
self.layer_count,
|
||||||
batch_count,
|
batch_count * 2,
|
||||||
self.image_token_count,
|
self.image_token_count,
|
||||||
self.attention_head_count,
|
self.embed_count
|
||||||
self.embed_count // self.attention_head_count
|
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_state = SampleState(
|
initial_state = SampleState(
|
||||||
prev_token = self.start_token,
|
prev_token = self.start_token,
|
||||||
prng_key = prng_key,
|
prng_key = prng_key,
|
||||||
keys_state = jnp.zeros(state_shape),
|
attention_state = jnp.zeros(attention_state_shape)
|
||||||
values_state = jnp.zeros(state_shape)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_, image_tokens = lax.scan(
|
_, image_tokens = lax.scan(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
|
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
|
||||||
|
|
||||||
|
@ -16,42 +16,35 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
||||||
keys = self.k_proj.forward(encoder_state)
|
keys = self.k_proj.forward(encoder_state)
|
||||||
values = self.v_proj.forward(encoder_state)
|
values = self.v_proj.forward(encoder_state)
|
||||||
queries = self.q_proj.forward(decoder_state)
|
queries = self.q_proj.forward(decoder_state)
|
||||||
query_shape = queries.shape[:2] + (self.head_count, -1)
|
|
||||||
key_value_shape = keys.shape[:2] + (self.head_count, -1)
|
|
||||||
keys = keys.reshape(key_value_shape)
|
|
||||||
values = values.reshape(key_value_shape)
|
|
||||||
queries = queries.reshape(query_shape)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
keys_values: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
token_index: LongTensor
|
token_mask: BoolTensor
|
||||||
) -> Tuple[FloatTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
batch_count = decoder_state.shape[0]
|
keys = self.k_proj.forward(decoder_state)
|
||||||
token_count = keys_values.shape[1]
|
values = self.v_proj.forward(decoder_state)
|
||||||
shape = (batch_count, 1) + keys_values.shape[2:]
|
queries = self.q_proj.forward(decoder_state)
|
||||||
keys = self.k_proj.forward(decoder_state).view(shape)
|
attention_state = torch.where(
|
||||||
values = self.v_proj.forward(decoder_state).view(shape)
|
token_mask[None, :, None],
|
||||||
token_mask = torch.arange(token_count) == token_index
|
|
||||||
keys_values = torch.where(
|
|
||||||
token_mask[None, :, None, None],
|
|
||||||
torch.cat([keys, values]),
|
torch.cat([keys, values]),
|
||||||
keys_values
|
attention_state
|
||||||
)
|
)
|
||||||
queries = self.q_proj.forward(decoder_state).reshape(shape)
|
batch_count = decoder_state.shape[0]
|
||||||
queries /= queries.shape[-1] ** 0.5
|
keys = attention_state[:batch_count]
|
||||||
keys, values = keys_values[:batch_count], keys_values[batch_count:]
|
values = attention_state[batch_count:]
|
||||||
decoder_state = super().forward(keys, values, queries, attention_mask)
|
decoder_state = super().forward(keys, values, queries, attention_mask)
|
||||||
return decoder_state, keys_values
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayerTorch(nn.Module):
|
class DecoderLayerTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
head_count: int,
|
head_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
|
@ -67,23 +60,29 @@ class DecoderLayerTorch(nn.Module):
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.glu = GLUTorch(embed_count, glu_embed_count)
|
self.glu = GLUTorch(embed_count, glu_embed_count)
|
||||||
|
|
||||||
def forward(self,
|
self.token_indices = torch.arange(self.image_token_count)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.token_indices = self.token_indices.cuda()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
token_index: LongTensor
|
token_index: LongTensor
|
||||||
) -> Tuple[FloatTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
residual = decoder_state
|
residual = decoder_state
|
||||||
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
|
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
|
||||||
self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
|
self_attn_mask = self.token_indices < token_index + 1
|
||||||
|
token_mask = self.token_indices == token_index
|
||||||
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
||||||
decoder_state, keys_values_state = self.self_attn.forward(
|
decoder_state, attention_state = self.self_attn.forward(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
keys_values_state,
|
attention_state,
|
||||||
self_attn_mask,
|
self_attn_mask,
|
||||||
token_index
|
token_mask
|
||||||
)
|
)
|
||||||
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
|
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
@ -104,11 +103,12 @@ class DecoderLayerTorch(nn.Module):
|
||||||
decoder_state = self.glu.forward(decoder_state)
|
decoder_state = self.glu.forward(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
|
||||||
return decoder_state, keys_values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_vocab_size: int,
|
image_vocab_size: int,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
sample_token_count: int,
|
sample_token_count: int,
|
||||||
|
@ -124,13 +124,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.is_verbose = is_verbose
|
self.is_verbose = is_verbose
|
||||||
self.layer_count = layer_count
|
self.layer_count = layer_count
|
||||||
self.sample_token_count = sample_token_count
|
self.sample_token_count = sample_token_count
|
||||||
self.start_token = torch.tensor([start_token]).to(torch.long)
|
self.condition_factor = 10.0
|
||||||
self.pad_token = torch.tensor([1]).to(torch.long)
|
|
||||||
self.condition_factor = torch.tensor([10]).to(torch.float)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.start_token = self.start_token.cuda()
|
|
||||||
self.pad_token = self.pad_token.cuda()
|
|
||||||
self.condition_factor = self.condition_factor.cuda()
|
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
|
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
|
||||||
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
||||||
|
@ -146,77 +140,82 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||||
self.final_ln = nn.LayerNorm(embed_count)
|
self.final_ln = nn.LayerNorm(embed_count)
|
||||||
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
||||||
self.keys_values_state_shape = (
|
self.attention_state_shape = (
|
||||||
layer_count * 2 * batch_count,
|
layer_count,
|
||||||
|
2 * batch_count,
|
||||||
image_token_count,
|
image_token_count,
|
||||||
attention_head_count,
|
embed_count
|
||||||
embed_count // attention_head_count
|
|
||||||
)
|
)
|
||||||
|
self.zero_prob = torch.zeros([1])
|
||||||
|
self.token_indices = torch.arange(self.sample_token_count)
|
||||||
|
self.start_token = torch.tensor([start_token]).to(torch.long)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.zero_prob = self.zero_prob.cuda()
|
||||||
|
self.token_indices = self.token_indices.cuda()
|
||||||
|
self.start_token = self.start_token.cuda()
|
||||||
|
|
||||||
|
|
||||||
def decode_step(self,
|
def decode_step(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
prev_token_and_index: LongTensor
|
prev_token: LongTensor,
|
||||||
|
token_index: LongTensor
|
||||||
) -> Tuple[LongTensor, FloatTensor]:
|
) -> Tuple[LongTensor, FloatTensor]:
|
||||||
attention_mask = text_tokens.not_equal(self.pad_token)
|
attention_mask = text_tokens.not_equal(1)
|
||||||
batch_count = encoder_state.shape[0]
|
batch_count = encoder_state.shape[0]
|
||||||
prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
|
prev_token_batched = torch.cat([prev_token] * batch_count)
|
||||||
token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
|
token_index_batched = torch.cat([token_index] * batch_count)
|
||||||
decoder_state = self.embed_tokens.forward(prev_token)
|
decoder_state = self.embed_tokens.forward(prev_token_batched)
|
||||||
decoder_state += self.embed_positions.forward(token_index)
|
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||||
decoder_state = decoder_state[:, None]
|
decoder_state = decoder_state[:, None]
|
||||||
keys_values = []
|
attention_states_new = []
|
||||||
for i, layer in enumerate(self.layers):
|
for i in range(self.layer_count):
|
||||||
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
|
decoder_state, attention_state_layer = self.layers[i].forward(
|
||||||
decoder_state, keys_values_layer = layer.forward(
|
|
||||||
decoder_state,
|
decoder_state,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
keys_values_state[j1:j2],
|
attention_state[i],
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_index[:1]
|
token_index
|
||||||
)
|
)
|
||||||
keys_values.append(keys_values_layer)
|
attention_states_new.append(attention_state_layer)
|
||||||
keys_values = torch.cat(keys_values, dim=0)
|
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
|
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||||
|
|
||||||
top_logits = logits.sort(descending=True)[0][:50]
|
top_logits, _ = logits.topk(50, dim=-1)
|
||||||
probs = torch.where(
|
probs = torch.where(
|
||||||
logits < top_logits[-1],
|
logits < top_logits[-1],
|
||||||
torch.zeros([1]),
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[0])
|
torch.exp(logits - top_logits[0])
|
||||||
)
|
)
|
||||||
return probs, keys_values
|
return probs, torch.stack(attention_states_new)
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor
|
encoder_state: FloatTensor
|
||||||
) -> LongTensor:
|
) -> LongTensor:
|
||||||
image_tokens: List[LongTensor] = []
|
image_tokens: List[LongTensor] = []
|
||||||
keys_values_state = torch.zeros(self.keys_values_state_shape)
|
attention_state = torch.zeros(self.attention_state_shape)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
attention_state = attention_state.cuda()
|
||||||
image_token = self.start_token
|
image_token = self.start_token
|
||||||
|
|
||||||
for i in range(self.sample_token_count):
|
for i in range(self.sample_token_count):
|
||||||
token_index = torch.tensor([i]).to(torch.long)
|
probs, attention_state = self.decode_step(
|
||||||
if torch.cuda.is_available(): token_index = token_index.cuda()
|
|
||||||
probs, keys_values_state = self.decode_step(
|
|
||||||
text_tokens = text_tokens,
|
text_tokens = text_tokens,
|
||||||
encoder_state = encoder_state,
|
encoder_state = encoder_state,
|
||||||
keys_values_state = keys_values_state,
|
attention_state = attention_state,
|
||||||
prev_token_and_index = torch.cat([image_token, token_index])
|
prev_token = image_token,
|
||||||
|
token_index = self.token_indices[[i]]
|
||||||
)
|
)
|
||||||
|
|
||||||
image_token = torch.multinomial(probs, 1)
|
image_token = torch.multinomial(probs, 1)
|
||||||
image_tokens += [image_token]
|
image_tokens += [image_token]
|
||||||
|
|
||||||
if self.is_verbose:
|
|
||||||
token = int(image_token.detach().numpy())
|
|
||||||
print("image token {} is {}".format(i, token))
|
|
||||||
|
|
||||||
return torch.cat(image_tokens)
|
return torch.cat(image_tokens)
|
|
@ -34,12 +34,17 @@ class AttentionFlax(nn.Module):
|
||||||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: jnp.ndarray,
|
keys: jnp.ndarray,
|
||||||
values: jnp.ndarray,
|
values: jnp.ndarray,
|
||||||
queries: jnp.ndarray,
|
queries: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||||
|
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||||
|
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||||
|
queries /= queries.shape[-1] ** 0.5
|
||||||
attention_bias: jnp.ndarray = lax.select(
|
attention_bias: jnp.ndarray = lax.select(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
jnp.full(attention_mask.shape, 0.0),
|
jnp.full(attention_mask.shape, 0.0),
|
||||||
|
@ -69,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj(encoder_state)
|
||||||
keys = self.k_proj(encoder_state).reshape(shape_split)
|
values = self.v_proj(encoder_state)
|
||||||
values = self.v_proj(encoder_state).reshape(shape_split)
|
queries = self.q_proj(encoder_state)
|
||||||
queries = self.q_proj(encoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
return self.forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +95,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
@ -120,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
|
||||||
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
|
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
|
||||||
self.layers = nn.scan(
|
self.layers = nn.scan(
|
||||||
DalleBartEncoderLayerFlax,
|
DalleBartEncoderLayerFlax,
|
||||||
variable_axes = { "params": 0, "cache": 0 },
|
variable_axes = { "params": 0 },
|
||||||
split_rngs = { "params": True },
|
split_rngs = { "params": True },
|
||||||
in_axes = nn.broadcast,
|
in_axes = nn.broadcast,
|
||||||
length = self.layer_count
|
length = self.layer_count
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
|
||||||
class GLUTorch(nn.Module):
|
class GLUTorch(nn.Module):
|
||||||
|
@ -34,17 +34,25 @@ class AttentionTorch(nn.Module):
|
||||||
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
|
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||||
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
|
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||||
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
|
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||||
|
self.one = torch.ones((1, 1))
|
||||||
|
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: FloatTensor,
|
keys: FloatTensor,
|
||||||
values: FloatTensor,
|
values: FloatTensor,
|
||||||
queries: FloatTensor,
|
queries: FloatTensor,
|
||||||
attention_mask: BoolTensor
|
attention_mask: BoolTensor
|
||||||
) -> FloatTensor:
|
) -> FloatTensor:
|
||||||
|
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||||
|
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||||
|
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||||
|
queries /= queries.shape[-1] ** 0.5
|
||||||
|
|
||||||
attention_bias = torch.where(
|
attention_bias = torch.where(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
torch.full(attention_mask.shape, 0.0),
|
self.one * 0,
|
||||||
torch.full(attention_mask.shape, -torch.inf),
|
self.one * (-torch.inf),
|
||||||
)
|
)
|
||||||
attention_weights: FloatTensor = torch.einsum(
|
attention_weights: FloatTensor = torch.einsum(
|
||||||
'bqhc,bkhc->bhqk',
|
'bqhc,bkhc->bhqk',
|
||||||
|
@ -70,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
attention_mask: BoolTensor
|
attention_mask: BoolTensor
|
||||||
) -> FloatTensor:
|
) -> FloatTensor:
|
||||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj.forward(encoder_state)
|
||||||
keys = self.k_proj.forward(encoder_state).reshape(shape_split)
|
values = self.v_proj.forward(encoder_state)
|
||||||
values = self.v_proj.forward(encoder_state).reshape(shape_split)
|
queries = self.q_proj.forward(encoder_state)
|
||||||
queries = self.q_proj.forward(encoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,7 +109,8 @@ class EncoderLayerTorch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderTorch(nn.Module):
|
class DalleBartEncoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
attention_head_count: int,
|
attention_head_count: int,
|
||||||
|
@ -124,11 +131,14 @@ class DalleBartEncoderTorch(nn.Module):
|
||||||
])
|
])
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||||
self.final_ln = nn.LayerNorm(embed_count)
|
self.final_ln = nn.LayerNorm(embed_count)
|
||||||
|
self.token_indices = torch.arange(text_token_count).to(torch.long)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.token_indices = self.token_indices.cuda()
|
||||||
|
|
||||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||||
attention_mask = text_tokens.not_equal(1)
|
attention_mask = text_tokens.not_equal(1)
|
||||||
batch_count, token_count = text_tokens.shape
|
batch_count = text_tokens.shape[0]
|
||||||
pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
|
pose_tokens = torch.stack([self.token_indices] * batch_count)
|
||||||
encoder_state = (
|
encoder_state = (
|
||||||
self.embed_tokens.forward(text_tokens) +
|
self.embed_tokens.forward(text_tokens) +
|
||||||
self.embed_positions.forward(pose_tokens)
|
self.embed_positions.forward(pose_tokens)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
BATCH_COUNT: int = 1
|
BATCH_COUNT: int = 1
|
||||||
|
|
||||||
|
@ -61,6 +61,7 @@ class AttentionBlock(Module):
|
||||||
h = self.proj_out.forward(h)
|
h = self.proj_out.forward(h)
|
||||||
return x + h
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
class MiddleLayer(Module):
|
class MiddleLayer(Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -74,6 +75,7 @@ class MiddleLayer(Module):
|
||||||
h = self.block_2.forward(h)
|
h = self.block_2.forward(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class Upsample(Module):
|
class Upsample(Module):
|
||||||
def __init__(self, log2_count):
|
def __init__(self, log2_count):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -86,6 +88,7 @@ class Upsample(Module):
|
||||||
x = self.conv.forward(x)
|
x = self.conv.forward(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class UpsampleBlock(Module):
|
class UpsampleBlock(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -124,6 +127,7 @@ class UpsampleBlock(Module):
|
||||||
h = self.upsample.forward(h)
|
h = self.upsample.forward(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class Decoder(Module):
|
class Decoder(Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -154,6 +158,7 @@ class Decoder(Module):
|
||||||
z = self.conv_out.forward(z)
|
z = self.conv_out.forward(z)
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
class VQGanDetokenizer(Module):
|
class VQGanDetokenizer(Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -8,7 +8,7 @@ class TextTokenizer:
|
||||||
pairs = [tuple(pair.split()) for pair in merges]
|
pairs = [tuple(pair.split()) for pair in merges]
|
||||||
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
||||||
|
|
||||||
def __call__(self, text: str) -> List[int]:
|
def tokenize(self, text: str) -> List[int]:
|
||||||
sep_token = self.token_from_subword['</s>']
|
sep_token = self.token_from_subword['</s>']
|
||||||
cls_token = self.token_from_subword['<s>']
|
cls_token = self.token_from_subword['<s>']
|
||||||
unk_token = self.token_from_subword['<unk>']
|
unk_token = self.token_from_subword['<unk>']
|
||||||
|
|
23
predict.py
Normal file
23
predict.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import tempfile
|
||||||
|
from cog import BasePredictor, Path, Input
|
||||||
|
|
||||||
|
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||||
|
|
||||||
|
class Predictor(BasePredictor):
|
||||||
|
def setup(self):
|
||||||
|
self.model = MinDalleTorch(is_mega=True)
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
text: str = Input(
|
||||||
|
description="Text for generating images.",
|
||||||
|
),
|
||||||
|
seed: int = Input(
|
||||||
|
description="Specify the seed.",
|
||||||
|
),
|
||||||
|
) -> Path:
|
||||||
|
image = self.model.generate_image(text, seed)
|
||||||
|
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
||||||
|
image.save(str(out_path))
|
||||||
|
|
||||||
|
return out_path
|
1
requirements.txt
vendored
1
requirements.txt
vendored
|
@ -1,2 +1,3 @@
|
||||||
torch
|
torch
|
||||||
flax==0.4.2
|
flax==0.4.2
|
||||||
|
wandb
|
||||||
|
|
10
setup.sh
vendored
10
setup.sh
vendored
|
@ -1,15 +1,15 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
mkdir -p pretrained
|
mkdir -p pretrained/vqgan
|
||||||
|
|
||||||
# download vqgan
|
# download vqgan
|
||||||
git lfs install
|
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
|
||||||
git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
|
|
||||||
|
|
||||||
# download dalle-mini and dalle mega
|
# download dalle-mini and dalle mega
|
||||||
pip install wandb
|
python -m wandb login --anonymously
|
||||||
python -m wandb login
|
|
||||||
python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
||||||
python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
||||||
|
|
Loading…
Reference in New Issue
Block a user