refactored to load models once and run multiple times

main
Brett Kuprel 2 years ago
parent 1ef9b0b929
commit ed91ab4a30
  1. README.md (10 lines changed)
  2. image_from_text.py (53 lines changed)
  3. min_dalle/generate_image.py (78 lines changed)
  4. min_dalle/min_dalle.py (38 lines changed)
  5. min_dalle/min_dalle_flax.py (123 lines changed)
  6. min_dalle/min_dalle_torch.py (163 lines changed)
  7. min_dalle/models/dalle_bart_decoder_flax.py (12 lines changed)
  8. min_dalle/models/dalle_bart_decoder_torch.py (18 lines changed)
  9. min_dalle/models/dalle_bart_encoder_flax.py (6 lines changed)
  10. min_dalle/models/dalle_bart_encoder_torch.py (6 lines changed)
  11. min_dalle/text_tokenizer.py (2 lines changed)

@@ -2,18 +2,18 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy`, `torch`, and `flax`.
### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
Run `sh setup.sh` to install dependencies and download pretrained models. The `wandb` Python package is installed to download DALL·E Mini and DALL·E Mega. Alternatively, the models can be downloaded manually here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
The simplest way to get started is with the provided command-line Python script `image_from_text.py`. Here are some example runs:
```
python image_from_text.py --text='alien life' --seed=7
@@ -32,3 +32,7 @@ python image_from_text.py --text='court sketch of godzilla on trial' --mega --se
```
![Godzilla Trial](examples/godzilla_trial.png)
### Load once, run multiple times
The command line script loads the models and parameters each time it is run. The colab notebook demonstrates how to load the models once and then generate multiple images.
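For example, here is a minimal sketch of that pattern using the `MinDalleTorch` class added in this commit (the notebook's exact code may differ):
```
from min_dalle.min_dalle_torch import MinDalleTorch

# load the models and parameters once
model = MinDalleTorch(is_mega=True)

# then generate multiple images with different seeds
for seed in [7, 8, 9]:
    image = model.generate_image('alien life', seed)
    image.save('generated_{}.png'.format(seed))
```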

@@ -2,8 +2,8 @@ import argparse
import os
from PIL import Image
from min_dalle.generate_image import generate_image_from_text
from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
@@ -15,7 +15,7 @@ parser.set_defaults(torch=False)
parser.add_argument('--text', type=str)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
def ascii_from_image(image: Image.Image, size: int) -> str:
@@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
    return image


def generate_image(
    is_torch: bool,
    is_mega: bool,
    text: str,
    seed: int,
    image_path: str,
    sample_token_count: int
):
    if is_torch:
        image_generator = MinDalleTorch(is_mega, sample_token_count)
        image_tokens = image_generator.generate_image_tokens(text, seed)
        if sample_token_count < image_generator.config['image_length']:
            print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
            return
        else:
            image = image_generator.generate_image(text, seed)
    else:
        image_generator = MinDalleFlax(is_mega)
        image = image_generator.generate_image(text, seed)
    save_image(image, image_path)
    print(ascii_from_image(image, size=128))


if __name__ == '__main__':
    args = parser.parse_args()
    print(args)
    image = generate_image_from_text(
        text = args.text,
        is_mega = args.mega,
        is_torch = args.torch,
        seed = args.seed,
        image_token_count = args.image_token_count
    )
    if image != None:
        save_image(image, args.image_path)
        print(ascii_from_image(image, size=128))
    generate_image(
        is_torch=args.torch,
        is_mega=args.mega,
        text=args.text,
        seed=args.seed,
        image_path=args.image_path,
        sample_token_count=args.sample_token_count
    )

@@ -1,78 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
import torch

from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
    generate_image_tokens_torch,
    detokenize_torch
)


def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
    print("parsing metadata from {}".format(path))
    for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
        assert(os.path.exists(os.path.join(path, f)))
    with open(path + '/config.json', 'r') as f:
        config = json.load(f)
    with open(path + '/vocab.json') as f:
        vocab = json.load(f)
    with open(path + '/merges.txt') as f:
        merges = f.read().split("\n")[1:-1]
    return config, vocab, merges


def tokenize_text(
    text: str,
    config: dict,
    vocab: dict,
    merges: List[str]
) -> numpy.ndarray:
    print("tokenizing text")
    tokens = TextTokenizer(vocab, merges)(text)
    print("text tokens", tokens)
    text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
    text_tokens[0, :len(tokens)] = tokens
    text_tokens[1, :2] = [tokens[0], tokens[-1]]
    return text_tokens


def generate_image_from_text(
    text: str,
    is_mega: bool = False,
    is_torch: bool = False,
    seed: int = 0,
    image_token_count: int = 256
) -> Image.Image:
    model_name = 'mega' if is_mega else 'mini'
    model_path = './pretrained/dalle_bart_{}'.format(model_name)
    config, vocab, merges = load_dalle_bart_metadata(model_path)
    text_tokens = tokenize_text(text, config, vocab, merges)
    params_dalle_bart = load_dalle_bart_flax_params(model_path)
    if is_torch:
        image_tokens = generate_image_tokens_torch(
            text_tokens = text_tokens,
            seed = seed,
            config = config,
            params = params_dalle_bart,
            image_token_count = image_token_count
        )
        if image_token_count == config['image_length']:
            image = detokenize_torch(image_tokens, is_torch=True)
            return Image.fromarray(image)
        else:
            print(list(image_tokens.to('cpu').detach().numpy()))
    else:
        image_tokens = generate_image_tokens_flax(
            text_tokens = text_tokens,
            seed = seed,
            config = config,
            params = params_dalle_bart,
        )
        image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
        return Image.fromarray(image)

@@ -0,0 +1,38 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer


class MinDalle:
    def __init__(self, is_mega: bool):
        self.is_mega = is_mega
        model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
        model_path = os.path.join('pretrained', model_name)

        print("reading files from {}".format(model_path))
        with open(os.path.join(model_path, 'config.json'), 'r') as f:
            self.config = json.load(f)
        with open(os.path.join(model_path, 'vocab.json'), 'r') as f:
            vocab = json.load(f)
        with open(os.path.join(model_path, 'merges.txt'), 'r') as f:
            merges = f.read().split("\n")[1:-1]

        self.model_params = load_dalle_bart_flax_params(model_path)
        self.tokenizer = TextTokenizer(vocab, merges)
        self.detokenizer = VQGanDetokenizer()
        vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
        self.detokenizer.load_state_dict(vqgan_params)

    def tokenize_text(self, text: str) -> numpy.ndarray:
        print("tokenizing text")
        tokens = self.tokenizer.tokenize(text)
        print("text tokens", tokens)
        text_token_count = self.config['max_text_length']
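        # build a batch of two token sequences: row 0 holds the full prompt,
        # row 1 keeps only the first and last (start/end) tokens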
        text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
        text_tokens[0, :len(tokens)] = tokens
        text_tokens[1, :2] = [tokens[0], tokens[-1]]
        return text_tokens

@@ -1,79 +1,58 @@
import jax
from jax import numpy as jnp
import numpy
from PIL import Image
import torch

from .min_dalle import MinDalle
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax


def encode_flax(
    text_tokens: numpy.ndarray,
    config: dict,
    params: dict
) -> jnp.ndarray:
    print("loading flax encoder")
    encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
        attention_head_count = config['encoder_attention_heads'],
        embed_count = config['d_model'],
        glu_embed_count = config['encoder_ffn_dim'],
        text_token_count = config['max_text_length'],
        text_vocab_count = config['encoder_vocab_size'],
        layer_count = config['encoder_layers']
    ).bind({'params': params.pop('encoder')})

    print("encoding text tokens")
    encoder_state = encoder(text_tokens)
    del encoder
    return encoder_state


def decode_flax(
    text_tokens: jnp.ndarray,
    encoder_state: jnp.ndarray,
    config: dict,
    seed: int,
    params: dict
) -> jnp.ndarray:
    print("loading flax decoder")
    decoder = DalleBartDecoderFlax(
        image_token_count = config['image_length'],
        text_token_count = config['max_text_length'],
        image_vocab_count = config['image_vocab_size'],
        attention_head_count = config['decoder_attention_heads'],
        embed_count = config['d_model'],
        glu_embed_count = config['decoder_ffn_dim'],
        layer_count = config['decoder_layers'],
        start_token = config['decoder_start_token_id']
    )

    print("sampling image tokens")
    image_tokens = decoder.sample_image_tokens(
        text_tokens,
        encoder_state,
        jax.random.PRNGKey(seed),
        params.pop('decoder')
    )
    del decoder
    return image_tokens


def generate_image_tokens_flax(
    text_tokens: numpy.ndarray,
    seed: int,
    config: dict,
    params: dict
) -> numpy.ndarray:
    encoder_state = encode_flax(
        text_tokens,
        config,
        params
    )
    image_tokens = decode_flax(
        text_tokens,
        encoder_state,
        config,
        seed,
        params
    )
    image_tokens = numpy.array(image_tokens)
    print("image tokens", list(image_tokens))
    return image_tokens


class MinDalleFlax(MinDalle):
    def __init__(self, is_mega: bool):
        super().__init__(is_mega)
        print("initializing MinDalleFlax")

        print("loading encoder")
        self.encoder = DalleBartEncoderFlax(
            attention_head_count = self.config['encoder_attention_heads'],
            embed_count = self.config['d_model'],
            glu_embed_count = self.config['encoder_ffn_dim'],
            text_token_count = self.config['max_text_length'],
            text_vocab_count = self.config['encoder_vocab_size'],
            layer_count = self.config['encoder_layers']
        ).bind({'params': self.model_params.pop('encoder')})

        print("loading decoder")
        self.decoder = DalleBartDecoderFlax(
            image_token_count = self.config['image_length'],
            text_token_count = self.config['max_text_length'],
            image_vocab_count = self.config['image_vocab_size'],
            attention_head_count = self.config['decoder_attention_heads'],
            embed_count = self.config['d_model'],
            glu_embed_count = self.config['decoder_ffn_dim'],
            layer_count = self.config['decoder_layers'],
            start_token = self.config['decoder_start_token_id']
        )

    def generate_image(self, text: str, seed: int) -> Image.Image:
        text_tokens = self.tokenize_text(text)

        print("encoding text tokens")
        encoder_state = self.encoder(text_tokens)

        print("sampling image tokens")
        image_tokens = self.decoder.sample_image_tokens(
            text_tokens,
            encoder_state,
            jax.random.PRNGKey(seed),
            self.model_params['decoder']
        )
        image_tokens = torch.tensor(numpy.array(image_tokens))

        print("detokenizing image")
        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
        image = Image.fromarray(image.to('cpu').detach().numpy())
        return image

@@ -1,118 +1,83 @@
from random import sample
import numpy
import os
from PIL import Image
from typing import Dict
from torch import LongTensor, FloatTensor
from torch import LongTensor
import torch
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())

from .models.vqgan_detokenizer import VQGanDetokenizer
from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .load_params import (
    load_vqgan_torch_params,
    convert_dalle_bart_torch_from_flax_params
)


class MinDalleTorch(MinDalle):
    def __init__(self, is_mega: bool, sample_token_count: int = 256):
        super().__init__(is_mega)
        print("initializing MinDalleTorch")

def encode_torch(
    text_tokens: LongTensor,
    config: dict,
    params: dict
) -> FloatTensor:
    print("loading torch encoder")
    encoder = DalleBartEncoderTorch(
        layer_count = config['encoder_layers'],
        embed_count = config['d_model'],
        attention_head_count = config['encoder_attention_heads'],
        text_vocab_count = config['encoder_vocab_size'],
        text_token_count = config['max_text_length'],
        glu_embed_count = config['encoder_ffn_dim']
    )
    encoder_params = convert_dalle_bart_torch_from_flax_params(
        params.pop('encoder'),
        layer_count=config['encoder_layers'],
        is_encoder=True
    )
    encoder.load_state_dict(encoder_params, strict=False)
    del encoder_params
    if torch.cuda.is_available(): encoder = encoder.cuda()

        print("loading encoder")
        self.encoder = DalleBartEncoderTorch(
            layer_count = self.config['encoder_layers'],
            embed_count = self.config['d_model'],
            attention_head_count = self.config['encoder_attention_heads'],
            text_vocab_count = self.config['encoder_vocab_size'],
            text_token_count = self.config['max_text_length'],
            glu_embed_count = self.config['encoder_ffn_dim']
        )
        encoder_params = convert_dalle_bart_torch_from_flax_params(
            self.model_params.pop('encoder'),
            layer_count=self.config['encoder_layers'],
            is_encoder=True
        )
        self.encoder.load_state_dict(encoder_params, strict=False)

    print("encoding text tokens")
    encoder_state = encoder(text_tokens)
    del encoder
    return encoder_state

        print("loading decoder")
        self.decoder = DalleBartDecoderTorch(
            image_vocab_size = self.config['image_vocab_size'],
            image_token_count = self.config['image_length'],
            sample_token_count = sample_token_count,
            embed_count = self.config['d_model'],
            attention_head_count = self.config['decoder_attention_heads'],
            glu_embed_count = self.config['decoder_ffn_dim'],
            layer_count = self.config['decoder_layers'],
            batch_count = 2,
            start_token = self.config['decoder_start_token_id'],
            is_verbose = True
        )
        decoder_params = convert_dalle_bart_torch_from_flax_params(
            self.model_params.pop('decoder'),
            layer_count=self.config['decoder_layers'],
            is_encoder=False
        )
        self.decoder.load_state_dict(decoder_params, strict=False)

        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.detokenizer = self.detokenizer.cuda()

def decode_torch(
    text_tokens: LongTensor,
    encoder_state: FloatTensor,
    config: dict,
    seed: int,
    params: dict,
    image_token_count: int
) -> LongTensor:
    print("loading torch decoder")
    decoder = DalleBartDecoderTorch(
        image_vocab_size = config['image_vocab_size'],
        image_token_count = config['image_length'],
        sample_token_count = image_token_count,
        embed_count = config['d_model'],
        attention_head_count = config['decoder_attention_heads'],
        glu_embed_count = config['decoder_ffn_dim'],
        layer_count = config['decoder_layers'],
        batch_count = 2,
        start_token = config['decoder_start_token_id'],
        is_verbose = True
    )
    decoder_params = convert_dalle_bart_torch_from_flax_params(
        params.pop('decoder'),
        layer_count=config['decoder_layers'],
        is_encoder=False
    )
    decoder.load_state_dict(decoder_params, strict=False)
    del decoder_params
    if torch.cuda.is_available(): decoder = decoder.cuda()
    print("sampling image tokens")
    torch.manual_seed(seed)
    image_tokens = decoder.forward(text_tokens, encoder_state)
    return image_tokens

    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
        text_tokens = self.tokenize_text(text)
        text_tokens = torch.tensor(text_tokens).to(torch.long)
        if torch.cuda.is_available(): text_tokens = text_tokens.cuda()

        print("encoding text tokens")
        encoder_state = self.encoder.forward(text_tokens)

def generate_image_tokens_torch(
    text_tokens: numpy.ndarray,
    seed: int,
    config: dict,
    params: dict,
    image_token_count: int
) -> LongTensor:
    text_tokens = torch.tensor(text_tokens).to(torch.long)
    if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
    encoder_state = encode_torch(
        text_tokens,
        config,
        params
    )
    image_tokens = decode_torch(
        text_tokens,
        encoder_state,
        config,
        seed,
        params,
        image_token_count
    )
    return image_tokens

        print("sampling image tokens")
        torch.manual_seed(seed)
        image_tokens = self.decoder.forward(text_tokens, encoder_state)
        return image_tokens

def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
    print("detokenizing image")
    model_path = './pretrained/vqgan'
    params = load_vqgan_torch_params(model_path)
    detokenizer = VQGanDetokenizer()
    detokenizer.load_state_dict(params)
    if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
    image = detokenizer.forward(image_tokens).to(torch.uint8)
    del detokenizer, params
    return image.to('cpu').detach().numpy()

    def generate_image(self, text: str, seed: int) -> Image.Image:
        image_tokens = self.generate_image_tokens(text, seed)

        print("detokenizing image")
        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
        image = Image.fromarray(image.to('cpu').detach().numpy())
        return image

@@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):

class DecoderSelfAttentionFlax(AttentionFlax):
    def __call__(self,
    def __call__(
        self,
        decoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,
        values_state: jnp.ndarray,

@@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(self,
    def __call__(
        self,
        decoder_state: jnp.ndarray,
        encoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,

@@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
        self.final_ln = nn.LayerNorm(use_scale=False)
        self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)

    def __call__(self,
    def __call__(
        self,
        encoder_state: jnp.ndarray,
        keys_state: jnp.ndarray,
        values_state: jnp.ndarray,

@@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
        decoder_state = self.lm_head(decoder_state)
        return decoder_state, keys_state, values_state

    def sample_image_tokens(self,
    def sample_image_tokens(
        self,
        text_tokens: jnp.ndarray,
        encoder_state: jnp.ndarray,
        prng_key: jax.random.PRNGKey,

@@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):

class DecoderSelfAttentionTorch(AttentionTorch):
    def forward(self,
    def forward(
        self,
        decoder_state: FloatTensor,
        keys_values: FloatTensor,
        attention_mask: BoolTensor,

@@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):

class DecoderLayerTorch(nn.Module):
    def __init__(self,
    def __init__(
        self,
        image_token_count: int,
        head_count: int,
        embed_count: int,

@@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
        if torch.cuda.is_available():
            self.token_indices = self.token_indices.cuda()

    def forward(self,
    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,

@@ -111,7 +114,8 @@

class DalleBartDecoderTorch(nn.Module):
    def __init__(self,
    def __init__(
        self,
        image_vocab_size: int,
        image_token_count: int,
        sample_token_count: int,

@@ -158,7 +162,8 @@
            self.start_token = self.start_token.cuda()

    def decode_step(self,
    def decode_step(
        self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,

@@ -198,7 +203,8 @@
        return probs, keys_values

    def forward(self,
    def forward(
        self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor
    ) -> LongTensor:

@@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
        self.v_proj = nn.Dense(self.embed_count, use_bias=False)
        self.out_proj = nn.Dense(self.embed_count, use_bias=False)

    def forward(self,
    def forward(
        self,
        keys: jnp.ndarray,
        values: jnp.ndarray,
        queries: jnp.ndarray,

@@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
        self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

    @nn.compact
    def __call__(self,
    def __call__(
        self,
        encoder_state: jnp.ndarray,
        attention_mask: jnp.ndarray
    ) -> jnp.ndarray:

@@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
        self.one = torch.ones((1, 1))
        if torch.cuda.is_available(): self.one = self.one.cuda()

    def forward(self,
    def forward(
        self,
        keys: FloatTensor,
        values: FloatTensor,
        queries: FloatTensor,

@@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):

class DalleBartEncoderTorch(nn.Module):
    def __init__(self,
    def __init__(
        self,
        layer_count: int,
        embed_count: int,
        attention_head_count: int,

@@ -8,7 +8,7 @@ class TextTokenizer:
        pairs = [tuple(pair.split()) for pair in merges]
        self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

    def __call__(self, text: str) -> List[int]:
    def tokenize(self, text: str) -> List[int]:
        sep_token = self.token_from_subword['</s>']
        cls_token = self.token_from_subword['<s>']
        unk_token = self.token_from_subword['<unk>']
