Merge branch 'kuprel:main' into replicate
This commit is contained in:
commit
fcc17c895d
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
* linguist-vendored
|
||||||
|
*.py linguist-vendored=false
|
16
README.md
vendored
16
README.md
vendored
|
@ -3,27 +3,31 @@
|
||||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
|
||||||
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
||||||
|
|
||||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
|
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
|
||||||
|
|
||||||
|
DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
|
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
|
Use the python script `image_from_text.py` to generate images from the command line. Note: the command line script loads the models and parameters each time. To load a model once and generate multiple times, initialize either `MinDalleTorch` or `MinDalleFlax`, then call `generate_image` with some text and a seed. See the colab for an example.
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='alien life' --seed=7
|
python image_from_text.py --text='artificial intelligence' --torch
|
||||||
```
|
```
|
||||||
![Alien](examples/alien.png)
|
![Alien](examples/artificial_intelligence.png)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=4
|
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
|
||||||
```
|
```
|
||||||
![Avocado Armchair](examples/avocado_armchair.png)
|
![Avocado Armchair](examples/avocado_armchair.png)
|
||||||
|
|
||||||
|
|
BIN
examples/alien.png
vendored
BIN
examples/alien.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 55 KiB |
BIN
examples/artificial_intelligence.png
vendored
Normal file
BIN
examples/artificial_intelligence.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 110 KiB |
BIN
examples/avocado_armchair.png
vendored
BIN
examples/avocado_armchair.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 90 KiB After Width: | Height: | Size: 101 KiB |
|
@ -2,8 +2,8 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from min_dalle.generate_image import generate_image_from_text
|
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||||
|
from min_dalle.min_dalle_flax import MinDalleFlax
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--mega', action='store_true')
|
parser.add_argument('--mega', action='store_true')
|
||||||
|
@ -12,10 +12,10 @@ parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--torch', action='store_true')
|
parser.add_argument('--torch', action='store_true')
|
||||||
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
||||||
parser.set_defaults(torch=False)
|
parser.set_defaults(torch=False)
|
||||||
parser.add_argument('--text', type=str)
|
parser.add_argument('--text', type=str, default='alien life')
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=7)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
|
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
|
||||||
def ascii_from_image(image: Image.Image, size: int) -> str:
|
def ascii_from_image(image: Image.Image, size: int) -> str:
|
||||||
|
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(
|
||||||
|
is_torch: bool,
|
||||||
|
is_mega: bool,
|
||||||
|
text: str,
|
||||||
|
seed: int,
|
||||||
|
image_path: str,
|
||||||
|
sample_token_count: int
|
||||||
|
):
|
||||||
|
if is_torch:
|
||||||
|
image_generator = MinDalleTorch(is_mega, sample_token_count)
|
||||||
|
image_tokens = image_generator.generate_image_tokens(text, seed)
|
||||||
|
|
||||||
|
if sample_token_count < image_generator.config['image_length']:
|
||||||
|
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
else:
|
||||||
|
image_generator = MinDalleFlax(is_mega)
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
save_image(image, image_path)
|
||||||
|
print(ascii_from_image(image, size=128))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
|
generate_image(
|
||||||
image = generate_image_from_text(
|
is_torch=args.torch,
|
||||||
text = args.text,
|
is_mega=args.mega,
|
||||||
is_mega = args.mega,
|
text=args.text,
|
||||||
is_torch = args.torch,
|
seed=args.seed,
|
||||||
seed = args.seed,
|
image_path=args.image_path,
|
||||||
image_token_count = args.image_token_count
|
sample_token_count=args.sample_token_count
|
||||||
)
|
)
|
||||||
|
|
||||||
if image != None:
|
|
||||||
save_image(image, args.image_path)
|
|
||||||
print(ascii_from_image(image, size=128))
|
|
92
min_dalle.ipynb
vendored
92
min_dalle.ipynb
vendored
File diff suppressed because one or more lines are too long
|
@ -1,78 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy
|
|
||||||
from PIL import Image
|
|
||||||
from typing import Tuple, List
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
|
||||||
from min_dalle.text_tokenizer import TextTokenizer
|
|
||||||
from min_dalle.min_dalle_flax import generate_image_tokens_flax
|
|
||||||
from min_dalle.min_dalle_torch import (
|
|
||||||
generate_image_tokens_torch,
|
|
||||||
detokenize_torch
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
|
|
||||||
print("parsing metadata from {}".format(path))
|
|
||||||
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
|
|
||||||
assert(os.path.exists(os.path.join(path, f)))
|
|
||||||
with open(path + '/config.json', 'r') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
with open(path + '/vocab.json') as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
with open(path + '/merges.txt') as f:
|
|
||||||
merges = f.read().split("\n")[1:-1]
|
|
||||||
return config, vocab, merges
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_text(
|
|
||||||
text: str,
|
|
||||||
config: dict,
|
|
||||||
vocab: dict,
|
|
||||||
merges: List[str]
|
|
||||||
) -> numpy.ndarray:
|
|
||||||
print("tokenizing text")
|
|
||||||
tokens = TextTokenizer(vocab, merges)(text)
|
|
||||||
print("text tokens", tokens)
|
|
||||||
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
|
|
||||||
text_tokens[0, :len(tokens)] = tokens
|
|
||||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
|
||||||
return text_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_from_text(
|
|
||||||
text: str,
|
|
||||||
is_mega: bool = False,
|
|
||||||
is_torch: bool = False,
|
|
||||||
seed: int = 0,
|
|
||||||
image_token_count: int = 256
|
|
||||||
) -> Image.Image:
|
|
||||||
model_name = 'mega' if is_mega else 'mini'
|
|
||||||
model_path = './pretrained/dalle_bart_{}'.format(model_name)
|
|
||||||
config, vocab, merges = load_dalle_bart_metadata(model_path)
|
|
||||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
|
||||||
params_dalle_bart = load_dalle_bart_flax_params(model_path)
|
|
||||||
|
|
||||||
if is_torch:
|
|
||||||
image_tokens = generate_image_tokens_torch(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
image_token_count = image_token_count
|
|
||||||
)
|
|
||||||
if image_token_count == config['image_length']:
|
|
||||||
image = detokenize_torch(image_tokens, is_torch=True)
|
|
||||||
return Image.fromarray(image)
|
|
||||||
else:
|
|
||||||
print(list(image_tokens.to('cpu').detach().numpy()))
|
|
||||||
else:
|
|
||||||
image_tokens = generate_image_tokens_flax(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
)
|
|
||||||
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
|
|
||||||
return Image.fromarray(image)
|
|
43
min_dalle/min_dalle.py
Normal file
43
min_dalle/min_dalle.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from .text_tokenizer import TextTokenizer
|
||||||
|
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
|
||||||
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
|
class MinDalle:
|
||||||
|
def __init__(self, is_mega: bool):
|
||||||
|
self.is_mega = is_mega
|
||||||
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||||
|
model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
|
print("reading files from {}".format(model_path))
|
||||||
|
config_path = os.path.join(model_path, 'config.json')
|
||||||
|
vocab_path = os.path.join(model_path, 'vocab.json')
|
||||||
|
merges_path = os.path.join(model_path, 'merges.txt')
|
||||||
|
|
||||||
|
with open(config_path, 'r', encoding='utf8') as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
with open(vocab_path, 'r', encoding='utf8') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
with open(merges_path, 'r', encoding='utf8') as f:
|
||||||
|
merges = f.read().split("\n")[1:-1]
|
||||||
|
|
||||||
|
self.model_params = load_dalle_bart_flax_params(model_path)
|
||||||
|
|
||||||
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
|
self.detokenizer = VQGanDetokenizer()
|
||||||
|
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||||
|
self.detokenizer.load_state_dict(vqgan_params)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||||
|
print("tokenizing text")
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
print("text tokens", tokens)
|
||||||
|
text_token_count = self.config['max_text_length']
|
||||||
|
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
||||||
|
text_tokens[0, :len(tokens)] = tokens
|
||||||
|
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
||||||
|
return text_tokens
|
|
@ -1,79 +1,58 @@
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .min_dalle import MinDalle
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||||
|
|
||||||
|
|
||||||
def encode_flax(
|
class MinDalleFlax(MinDalle):
|
||||||
text_tokens: numpy.ndarray,
|
def __init__(self, is_mega: bool):
|
||||||
config: dict,
|
super().__init__(is_mega)
|
||||||
params: dict
|
print("initializing MinDalleFlax")
|
||||||
) -> jnp.ndarray:
|
|
||||||
print("loading flax encoder")
|
print("loading encoder")
|
||||||
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
self.encoder = DalleBartEncoderFlax(
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
embed_count = config['d_model'],
|
embed_count = self.config['d_model'],
|
||||||
glu_embed_count = config['encoder_ffn_dim'],
|
glu_embed_count = self.config['encoder_ffn_dim'],
|
||||||
text_token_count = config['max_text_length'],
|
text_token_count = self.config['max_text_length'],
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
layer_count = config['encoder_layers']
|
layer_count = self.config['encoder_layers']
|
||||||
).bind({'params': params.pop('encoder')})
|
).bind({'params': self.model_params.pop('encoder')})
|
||||||
|
|
||||||
|
print("loading decoder")
|
||||||
|
self.decoder = DalleBartDecoderFlax(
|
||||||
|
image_token_count = self.config['image_length'],
|
||||||
|
text_token_count = self.config['max_text_length'],
|
||||||
|
image_vocab_count = self.config['image_vocab_size'],
|
||||||
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
|
embed_count = self.config['d_model'],
|
||||||
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
|
layer_count = self.config['decoder_layers'],
|
||||||
|
start_token = self.config['decoder_start_token_id']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
|
text_tokens = self.tokenize_text(text)
|
||||||
|
|
||||||
print("encoding text tokens")
|
print("encoding text tokens")
|
||||||
encoder_state = encoder(text_tokens)
|
encoder_state = self.encoder(text_tokens)
|
||||||
del encoder
|
|
||||||
return encoder_state
|
|
||||||
|
|
||||||
|
|
||||||
def decode_flax(
|
|
||||||
text_tokens: jnp.ndarray,
|
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
config: dict,
|
|
||||||
seed: int,
|
|
||||||
params: dict
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
print("loading flax decoder")
|
|
||||||
decoder = DalleBartDecoderFlax(
|
|
||||||
image_token_count = config['image_length'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
image_vocab_count = config['image_vocab_size'],
|
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
|
||||||
layer_count = config['decoder_layers'],
|
|
||||||
start_token = config['decoder_start_token_id']
|
|
||||||
)
|
|
||||||
print("sampling image tokens")
|
print("sampling image tokens")
|
||||||
image_tokens = decoder.sample_image_tokens(
|
image_tokens = self.decoder.sample_image_tokens(
|
||||||
text_tokens,
|
text_tokens,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
jax.random.PRNGKey(seed),
|
jax.random.PRNGKey(seed),
|
||||||
params.pop('decoder')
|
self.model_params['decoder']
|
||||||
)
|
)
|
||||||
del decoder
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
image_tokens = torch.tensor(numpy.array(image_tokens))
|
||||||
|
|
||||||
def generate_image_tokens_flax(
|
print("detokenizing image")
|
||||||
text_tokens: numpy.ndarray,
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
seed: int,
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
config: dict,
|
return image
|
||||||
params: dict
|
|
||||||
) -> numpy.ndarray:
|
|
||||||
encoder_state = encode_flax(
|
|
||||||
text_tokens,
|
|
||||||
config,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = decode_flax(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
config,
|
|
||||||
seed,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = numpy.array(image_tokens)
|
|
||||||
print("image tokens", list(image_tokens))
|
|
||||||
return image_tokens
|
|
|
@ -1,118 +1,83 @@
|
||||||
|
from random import sample
|
||||||
import numpy
|
import numpy
|
||||||
import os
|
import os
|
||||||
|
from PIL import Image
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from torch import LongTensor, FloatTensor
|
from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .load_params import convert_dalle_bart_torch_from_flax_params
|
||||||
|
from .min_dalle import MinDalle
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
|
||||||
from .load_params import (
|
|
||||||
load_vqgan_torch_params,
|
|
||||||
convert_dalle_bart_torch_from_flax_params
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class MinDalleTorch(MinDalle):
|
||||||
|
def __init__(self, is_mega: bool, sample_token_count: int = 256):
|
||||||
|
super().__init__(is_mega)
|
||||||
|
print("initializing MinDalleTorch")
|
||||||
|
|
||||||
def encode_torch(
|
print("loading encoder")
|
||||||
text_tokens: LongTensor,
|
self.encoder = DalleBartEncoderTorch(
|
||||||
config: dict,
|
layer_count = self.config['encoder_layers'],
|
||||||
params: dict
|
embed_count = self.config['d_model'],
|
||||||
) -> FloatTensor:
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
print("loading torch encoder")
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
encoder = DalleBartEncoderTorch(
|
text_token_count = self.config['max_text_length'],
|
||||||
layer_count = config['encoder_layers'],
|
glu_embed_count = self.config['encoder_ffn_dim']
|
||||||
embed_count = config['d_model'],
|
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
glu_embed_count = config['encoder_ffn_dim']
|
|
||||||
)
|
)
|
||||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
params.pop('encoder'),
|
self.model_params.pop('encoder'),
|
||||||
layer_count=config['encoder_layers'],
|
layer_count=self.config['encoder_layers'],
|
||||||
is_encoder=True
|
is_encoder=True
|
||||||
)
|
)
|
||||||
encoder.load_state_dict(encoder_params, strict=False)
|
self.encoder.load_state_dict(encoder_params, strict=False)
|
||||||
del encoder_params
|
|
||||||
if torch.cuda.is_available(): encoder = encoder.cuda()
|
|
||||||
|
|
||||||
print("encoding text tokens")
|
print("loading decoder")
|
||||||
encoder_state = encoder(text_tokens)
|
self.decoder = DalleBartDecoderTorch(
|
||||||
del encoder
|
image_vocab_size = self.config['image_vocab_size'],
|
||||||
return encoder_state
|
image_token_count = self.config['image_length'],
|
||||||
|
sample_token_count = sample_token_count,
|
||||||
|
embed_count = self.config['d_model'],
|
||||||
def decode_torch(
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
text_tokens: LongTensor,
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
encoder_state: FloatTensor,
|
layer_count = self.config['decoder_layers'],
|
||||||
config: dict,
|
|
||||||
seed: int,
|
|
||||||
params: dict,
|
|
||||||
image_token_count: int
|
|
||||||
) -> LongTensor:
|
|
||||||
print("loading torch decoder")
|
|
||||||
decoder = DalleBartDecoderTorch(
|
|
||||||
image_vocab_size = config['image_vocab_size'],
|
|
||||||
image_token_count = config['image_length'],
|
|
||||||
sample_token_count = image_token_count,
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
|
||||||
layer_count = config['decoder_layers'],
|
|
||||||
batch_count = 2,
|
batch_count = 2,
|
||||||
start_token = config['decoder_start_token_id'],
|
start_token = self.config['decoder_start_token_id'],
|
||||||
is_verbose = True
|
is_verbose = True
|
||||||
)
|
)
|
||||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
params.pop('decoder'),
|
self.model_params.pop('decoder'),
|
||||||
layer_count=config['decoder_layers'],
|
layer_count=self.config['decoder_layers'],
|
||||||
is_encoder=False
|
is_encoder=False
|
||||||
)
|
)
|
||||||
decoder.load_state_dict(decoder_params, strict=False)
|
self.decoder.load_state_dict(decoder_params, strict=False)
|
||||||
del decoder_params
|
|
||||||
if torch.cuda.is_available(): decoder = decoder.cuda()
|
if torch.cuda.is_available():
|
||||||
|
self.encoder = self.encoder.cuda()
|
||||||
|
self.decoder = self.decoder.cuda()
|
||||||
|
self.detokenizer = self.detokenizer.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||||
|
text_tokens = self.tokenize_text(text)
|
||||||
|
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||||
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||||
|
|
||||||
|
print("encoding text tokens")
|
||||||
|
encoder_state = self.encoder.forward(text_tokens)
|
||||||
|
|
||||||
print("sampling image tokens")
|
print("sampling image tokens")
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
image_tokens = decoder.forward(text_tokens, encoder_state)
|
image_tokens = self.decoder.forward(text_tokens, encoder_state)
|
||||||
return image_tokens
|
return image_tokens
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens_torch(
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
text_tokens: numpy.ndarray,
|
image_tokens = self.generate_image_tokens(text, seed)
|
||||||
seed: int,
|
|
||||||
config: dict,
|
|
||||||
params: dict,
|
|
||||||
image_token_count: int
|
|
||||||
) -> LongTensor:
|
|
||||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
|
||||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
|
||||||
encoder_state = encode_torch(
|
|
||||||
text_tokens,
|
|
||||||
config,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = decode_torch(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
config,
|
|
||||||
seed,
|
|
||||||
params,
|
|
||||||
image_token_count
|
|
||||||
)
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
|
|
||||||
print("detokenizing image")
|
print("detokenizing image")
|
||||||
model_path = './pretrained/vqgan'
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
params = load_vqgan_torch_params(model_path)
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
detokenizer = VQGanDetokenizer()
|
return image
|
||||||
detokenizer.load_state_dict(params)
|
|
||||||
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
|
|
||||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
|
||||||
del detokenizer, params
|
|
||||||
return image.to('cpu').detach().numpy()
|
|
|
@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
values_state: jnp.ndarray,
|
||||||
|
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
|
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
self.final_ln = nn.LayerNorm(use_scale=False)
|
||||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
||||||
|
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
values_state: jnp.ndarray,
|
||||||
|
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
decoder_state = self.lm_head(decoder_state)
|
decoder_state = self.lm_head(decoder_state)
|
||||||
return decoder_state, keys_state, values_state
|
return decoder_state, keys_state, values_state
|
||||||
|
|
||||||
def sample_image_tokens(self,
|
def sample_image_tokens(
|
||||||
|
self,
|
||||||
text_tokens: jnp.ndarray,
|
text_tokens: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
prng_key: jax.random.PRNGKey,
|
prng_key: jax.random.PRNGKey,
|
||||||
|
|
|
@ -16,40 +16,34 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
||||||
keys = self.k_proj.forward(encoder_state)
|
keys = self.k_proj.forward(encoder_state)
|
||||||
values = self.v_proj.forward(encoder_state)
|
values = self.v_proj.forward(encoder_state)
|
||||||
queries = self.q_proj.forward(decoder_state)
|
queries = self.q_proj.forward(decoder_state)
|
||||||
query_shape = queries.shape[:2] + (self.head_count, -1)
|
|
||||||
key_value_shape = keys.shape[:2] + (self.head_count, -1)
|
|
||||||
keys = keys.reshape(key_value_shape)
|
|
||||||
values = values.reshape(key_value_shape)
|
|
||||||
queries = queries.reshape(query_shape)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
keys_values: FloatTensor,
|
keys_values: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
token_mask: BoolTensor
|
token_mask: BoolTensor
|
||||||
) -> Tuple[FloatTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
batch_count = decoder_state.shape[0]
|
batch_count = decoder_state.shape[0]
|
||||||
shape = (batch_count, 1) + keys_values.shape[2:]
|
keys = self.k_proj.forward(decoder_state)
|
||||||
keys = self.k_proj.forward(decoder_state).view(shape)
|
values = self.v_proj.forward(decoder_state)
|
||||||
values = self.v_proj.forward(decoder_state).view(shape)
|
queries = self.q_proj.forward(decoder_state)
|
||||||
keys_values = torch.where(
|
keys_values = torch.where(
|
||||||
token_mask[None, :, None, None],
|
token_mask[None, :, None],
|
||||||
torch.cat([keys, values]),
|
torch.cat([keys, values]),
|
||||||
keys_values
|
keys_values
|
||||||
)
|
)
|
||||||
queries = self.q_proj.forward(decoder_state).reshape(shape)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
keys, values = keys_values[:batch_count], keys_values[batch_count:]
|
keys, values = keys_values[:batch_count], keys_values[batch_count:]
|
||||||
decoder_state = super().forward(keys, values, queries, attention_mask)
|
decoder_state = super().forward(keys, values, queries, attention_mask)
|
||||||
return decoder_state, keys_values
|
return decoder_state, keys_values
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayerTorch(nn.Module):
|
class DecoderLayerTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
head_count: int,
|
head_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
|
@ -69,7 +63,8 @@ class DecoderLayerTorch(nn.Module):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.token_indices = self.token_indices.cuda()
|
self.token_indices = self.token_indices.cuda()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
keys_values_state: FloatTensor,
|
||||||
|
@ -111,7 +106,8 @@ class DecoderLayerTorch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_vocab_size: int,
|
image_vocab_size: int,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
sample_token_count: int,
|
sample_token_count: int,
|
||||||
|
@ -146,8 +142,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.keys_values_state_shape = (
|
self.keys_values_state_shape = (
|
||||||
layer_count * 2 * batch_count,
|
layer_count * 2 * batch_count,
|
||||||
image_token_count,
|
image_token_count,
|
||||||
attention_head_count,
|
embed_count
|
||||||
embed_count // attention_head_count
|
|
||||||
)
|
)
|
||||||
self.zero_prob = torch.zeros([1])
|
self.zero_prob = torch.zeros([1])
|
||||||
self.token_indices = torch.arange(self.sample_token_count)
|
self.token_indices = torch.arange(self.sample_token_count)
|
||||||
|
@ -158,7 +153,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.start_token = self.start_token.cuda()
|
self.start_token = self.start_token.cuda()
|
||||||
|
|
||||||
|
|
||||||
def decode_step(self,
|
def decode_step(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
keys_values_state: FloatTensor,
|
||||||
|
@ -183,7 +179,6 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
token_index[:1]
|
token_index[:1]
|
||||||
)
|
)
|
||||||
keys_values.append(keys_values_layer)
|
keys_values.append(keys_values_layer)
|
||||||
keys_values = torch.cat(keys_values, dim=0)
|
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
|
@ -195,10 +190,11 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.zero_prob,
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[0])
|
torch.exp(logits - top_logits[0])
|
||||||
)
|
)
|
||||||
return probs, keys_values
|
return probs, torch.cat(keys_values)
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor
|
encoder_state: FloatTensor
|
||||||
) -> LongTensor:
|
) -> LongTensor:
|
||||||
|
|
|
@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
|
||||||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: jnp.ndarray,
|
keys: jnp.ndarray,
|
||||||
values: jnp.ndarray,
|
values: jnp.ndarray,
|
||||||
queries: jnp.ndarray,
|
queries: jnp.ndarray,
|
||||||
|
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
|
|
@ -37,12 +37,18 @@ class AttentionTorch(nn.Module):
|
||||||
self.one = torch.ones((1, 1))
|
self.one = torch.ones((1, 1))
|
||||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: FloatTensor,
|
keys: FloatTensor,
|
||||||
values: FloatTensor,
|
values: FloatTensor,
|
||||||
queries: FloatTensor,
|
queries: FloatTensor,
|
||||||
attention_mask: BoolTensor
|
attention_mask: BoolTensor
|
||||||
) -> FloatTensor:
|
) -> FloatTensor:
|
||||||
|
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||||
|
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||||
|
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||||
|
queries /= queries.shape[-1] ** 0.5
|
||||||
|
|
||||||
attention_bias = torch.where(
|
attention_bias = torch.where(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
self.one * 0,
|
self.one * 0,
|
||||||
|
@ -72,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
attention_mask: BoolTensor
|
attention_mask: BoolTensor
|
||||||
) -> FloatTensor:
|
) -> FloatTensor:
|
||||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj.forward(encoder_state)
|
||||||
keys = self.k_proj.forward(encoder_state).reshape(shape_split)
|
values = self.v_proj.forward(encoder_state)
|
||||||
values = self.v_proj.forward(encoder_state).reshape(shape_split)
|
queries = self.q_proj.forward(encoder_state)
|
||||||
queries = self.q_proj.forward(encoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,7 +109,8 @@ class EncoderLayerTorch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderTorch(nn.Module):
|
class DalleBartEncoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
attention_head_count: int,
|
attention_head_count: int,
|
||||||
|
|
|
@ -8,7 +8,7 @@ class TextTokenizer:
|
||||||
pairs = [tuple(pair.split()) for pair in merges]
|
pairs = [tuple(pair.split()) for pair in merges]
|
||||||
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
||||||
|
|
||||||
def __call__(self, text: str) -> List[int]:
|
def tokenize(self, text: str) -> List[int]:
|
||||||
sep_token = self.token_from_subword['</s>']
|
sep_token = self.token_from_subword['</s>']
|
||||||
cls_token = self.token_from_subword['<s>']
|
cls_token = self.token_from_subword['<s>']
|
||||||
unk_token = self.token_from_subword['<unk>']
|
unk_token = self.token_from_subword['<unk>']
|
||||||
|
|
2
requirements.txt
vendored
2
requirements.txt
vendored
|
@ -1,3 +1,3 @@
|
||||||
torch
|
torch
|
||||||
flax==0.4.2
|
flax==0.5.2
|
||||||
wandb
|
wandb
|
Loading…
Reference in New Issue
Block a user