refactored to load models once and run multiple times
This commit is contained in:
parent
1ef9b0b929
commit
ed91ab4a30
10
README.md
10
README.md
|
@ -2,18 +2,18 @@
|
||||||
|
|
||||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||||
|
|
||||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
|
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy`, `torch`, and `flax`.
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
|
Run `sh setup.sh` to install dependencies and download pretrained models. The `wandb` python package is installed to download DALL·E mini and DALL·E mega. Alternatively, the models can be downloaded manually here:
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
|
The simplest way to get started is the command line python script `image_from_text.py` provided. Here are some examples runs:
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='alien life' --seed=7
|
python image_from_text.py --text='alien life' --seed=7
|
||||||
|
@ -32,3 +32,7 @@ python image_from_text.py --text='court sketch of godzilla on trial' --mega --se
|
||||||
```
|
```
|
||||||
|
|
||||||
![Godzilla Trial](examples/godzilla_trial.png)
|
![Godzilla Trial](examples/godzilla_trial.png)
|
||||||
|
|
||||||
|
### Load once run multiple times
|
||||||
|
|
||||||
|
The command line script loads the models and parameters each time. The colab notebook demonstrates how to load the models once and run multiple times.
|
|
@ -2,8 +2,8 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from min_dalle.generate_image import generate_image_from_text
|
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||||
|
from min_dalle.min_dalle_flax import MinDalleFlax
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--mega', action='store_true')
|
parser.add_argument('--mega', action='store_true')
|
||||||
|
@ -15,7 +15,7 @@ parser.set_defaults(torch=False)
|
||||||
parser.add_argument('--text', type=str)
|
parser.add_argument('--text', type=str)
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
|
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
|
||||||
def ascii_from_image(image: Image.Image, size: int) -> str:
|
def ascii_from_image(image: Image.Image, size: int) -> str:
|
||||||
|
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image(
|
||||||
|
is_torch: bool,
|
||||||
|
is_mega: bool,
|
||||||
|
text: str,
|
||||||
|
seed: int,
|
||||||
|
image_path: str,
|
||||||
|
sample_token_count: int
|
||||||
|
):
|
||||||
|
if is_torch:
|
||||||
|
image_generator = MinDalleTorch(is_mega, sample_token_count)
|
||||||
|
image_tokens = image_generator.generate_image_tokens(text, seed)
|
||||||
|
|
||||||
|
if sample_token_count < image_generator.config['image_length']:
|
||||||
|
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
else:
|
||||||
|
image_generator = MinDalleFlax(is_mega)
|
||||||
|
image = image_generator.generate_image(text, seed)
|
||||||
|
|
||||||
|
save_image(image, image_path)
|
||||||
|
print(ascii_from_image(image, size=128))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
|
generate_image(
|
||||||
image = generate_image_from_text(
|
is_torch=args.torch,
|
||||||
text = args.text,
|
is_mega=args.mega,
|
||||||
is_mega = args.mega,
|
text=args.text,
|
||||||
is_torch = args.torch,
|
seed=args.seed,
|
||||||
seed = args.seed,
|
image_path=args.image_path,
|
||||||
image_token_count = args.image_token_count
|
sample_token_count=args.sample_token_count
|
||||||
)
|
)
|
||||||
|
|
||||||
if image != None:
|
|
||||||
save_image(image, args.image_path)
|
|
||||||
print(ascii_from_image(image, size=128))
|
|
|
@ -1,78 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy
|
|
||||||
from PIL import Image
|
|
||||||
from typing import Tuple, List
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
|
||||||
from min_dalle.text_tokenizer import TextTokenizer
|
|
||||||
from min_dalle.min_dalle_flax import generate_image_tokens_flax
|
|
||||||
from min_dalle.min_dalle_torch import (
|
|
||||||
generate_image_tokens_torch,
|
|
||||||
detokenize_torch
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
|
|
||||||
print("parsing metadata from {}".format(path))
|
|
||||||
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
|
|
||||||
assert(os.path.exists(os.path.join(path, f)))
|
|
||||||
with open(path + '/config.json', 'r') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
with open(path + '/vocab.json') as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
with open(path + '/merges.txt') as f:
|
|
||||||
merges = f.read().split("\n")[1:-1]
|
|
||||||
return config, vocab, merges
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_text(
|
|
||||||
text: str,
|
|
||||||
config: dict,
|
|
||||||
vocab: dict,
|
|
||||||
merges: List[str]
|
|
||||||
) -> numpy.ndarray:
|
|
||||||
print("tokenizing text")
|
|
||||||
tokens = TextTokenizer(vocab, merges)(text)
|
|
||||||
print("text tokens", tokens)
|
|
||||||
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
|
|
||||||
text_tokens[0, :len(tokens)] = tokens
|
|
||||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
|
||||||
return text_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_from_text(
|
|
||||||
text: str,
|
|
||||||
is_mega: bool = False,
|
|
||||||
is_torch: bool = False,
|
|
||||||
seed: int = 0,
|
|
||||||
image_token_count: int = 256
|
|
||||||
) -> Image.Image:
|
|
||||||
model_name = 'mega' if is_mega else 'mini'
|
|
||||||
model_path = './pretrained/dalle_bart_{}'.format(model_name)
|
|
||||||
config, vocab, merges = load_dalle_bart_metadata(model_path)
|
|
||||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
|
||||||
params_dalle_bart = load_dalle_bart_flax_params(model_path)
|
|
||||||
|
|
||||||
if is_torch:
|
|
||||||
image_tokens = generate_image_tokens_torch(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
image_token_count = image_token_count
|
|
||||||
)
|
|
||||||
if image_token_count == config['image_length']:
|
|
||||||
image = detokenize_torch(image_tokens, is_torch=True)
|
|
||||||
return Image.fromarray(image)
|
|
||||||
else:
|
|
||||||
print(list(image_tokens.to('cpu').detach().numpy()))
|
|
||||||
else:
|
|
||||||
image_tokens = generate_image_tokens_flax(
|
|
||||||
text_tokens = text_tokens,
|
|
||||||
seed = seed,
|
|
||||||
config = config,
|
|
||||||
params = params_dalle_bart,
|
|
||||||
)
|
|
||||||
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
|
|
||||||
return Image.fromarray(image)
|
|
38
min_dalle/min_dalle.py
Normal file
38
min_dalle/min_dalle.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
from .text_tokenizer import TextTokenizer
|
||||||
|
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
|
||||||
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
|
class MinDalle:
|
||||||
|
def __init__(self, is_mega: bool):
|
||||||
|
self.is_mega = is_mega
|
||||||
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||||
|
model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
|
print("reading files from {}".format(model_path))
|
||||||
|
with open(os.path.join(model_path, 'config.json'), 'r') as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
with open(os.path.join(model_path, 'vocab.json'), 'r') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
with open(os.path.join(model_path, 'merges.txt'), 'r') as f:
|
||||||
|
merges = f.read().split("\n")[1:-1]
|
||||||
|
self.model_params = load_dalle_bart_flax_params(model_path)
|
||||||
|
|
||||||
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
|
self.detokenizer = VQGanDetokenizer()
|
||||||
|
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||||
|
self.detokenizer.load_state_dict(vqgan_params)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||||
|
print("tokenizing text")
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
print("text tokens", tokens)
|
||||||
|
text_token_count = self.config['max_text_length']
|
||||||
|
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
||||||
|
text_tokens[0, :len(tokens)] = tokens
|
||||||
|
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
||||||
|
return text_tokens
|
|
@ -1,79 +1,58 @@
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .min_dalle import MinDalle
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||||
|
|
||||||
|
|
||||||
def encode_flax(
|
class MinDalleFlax(MinDalle):
|
||||||
text_tokens: numpy.ndarray,
|
def __init__(self, is_mega: bool):
|
||||||
config: dict,
|
super().__init__(is_mega)
|
||||||
params: dict
|
print("initializing MinDalleFlax")
|
||||||
) -> jnp.ndarray:
|
|
||||||
print("loading flax encoder")
|
|
||||||
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
glu_embed_count = config['encoder_ffn_dim'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
|
||||||
layer_count = config['encoder_layers']
|
|
||||||
).bind({'params': params.pop('encoder')})
|
|
||||||
|
|
||||||
print("encoding text tokens")
|
print("loading encoder")
|
||||||
encoder_state = encoder(text_tokens)
|
self.encoder = DalleBartEncoderFlax(
|
||||||
del encoder
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
return encoder_state
|
embed_count = self.config['d_model'],
|
||||||
|
glu_embed_count = self.config['encoder_ffn_dim'],
|
||||||
|
text_token_count = self.config['max_text_length'],
|
||||||
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
|
layer_count = self.config['encoder_layers']
|
||||||
|
).bind({'params': self.model_params.pop('encoder')})
|
||||||
|
|
||||||
|
print("loading decoder")
|
||||||
|
self.decoder = DalleBartDecoderFlax(
|
||||||
|
image_token_count = self.config['image_length'],
|
||||||
|
text_token_count = self.config['max_text_length'],
|
||||||
|
image_vocab_count = self.config['image_vocab_size'],
|
||||||
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
|
embed_count = self.config['d_model'],
|
||||||
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
|
layer_count = self.config['decoder_layers'],
|
||||||
|
start_token = self.config['decoder_start_token_id']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_flax(
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
text_tokens: jnp.ndarray,
|
text_tokens = self.tokenize_text(text)
|
||||||
encoder_state: jnp.ndarray,
|
|
||||||
config: dict,
|
|
||||||
seed: int,
|
|
||||||
params: dict
|
|
||||||
) -> jnp.ndarray:
|
|
||||||
print("loading flax decoder")
|
|
||||||
decoder = DalleBartDecoderFlax(
|
|
||||||
image_token_count = config['image_length'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
image_vocab_count = config['image_vocab_size'],
|
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
|
||||||
layer_count = config['decoder_layers'],
|
|
||||||
start_token = config['decoder_start_token_id']
|
|
||||||
)
|
|
||||||
print("sampling image tokens")
|
|
||||||
image_tokens = decoder.sample_image_tokens(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
jax.random.PRNGKey(seed),
|
|
||||||
params.pop('decoder')
|
|
||||||
)
|
|
||||||
del decoder
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
print("encoding text tokens")
|
||||||
|
encoder_state = self.encoder(text_tokens)
|
||||||
|
|
||||||
def generate_image_tokens_flax(
|
print("sampling image tokens")
|
||||||
text_tokens: numpy.ndarray,
|
image_tokens = self.decoder.sample_image_tokens(
|
||||||
seed: int,
|
text_tokens,
|
||||||
config: dict,
|
encoder_state,
|
||||||
params: dict
|
jax.random.PRNGKey(seed),
|
||||||
) -> numpy.ndarray:
|
self.model_params['decoder']
|
||||||
encoder_state = encode_flax(
|
)
|
||||||
text_tokens,
|
|
||||||
config,
|
image_tokens = torch.tensor(numpy.array(image_tokens))
|
||||||
params
|
|
||||||
)
|
print("detokenizing image")
|
||||||
image_tokens = decode_flax(
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
text_tokens,
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
encoder_state,
|
return image
|
||||||
config,
|
|
||||||
seed,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = numpy.array(image_tokens)
|
|
||||||
print("image tokens", list(image_tokens))
|
|
||||||
return image_tokens
|
|
|
@ -1,118 +1,83 @@
|
||||||
|
from random import sample
|
||||||
import numpy
|
import numpy
|
||||||
import os
|
import os
|
||||||
|
from PIL import Image
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from torch import LongTensor, FloatTensor
|
from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .load_params import convert_dalle_bart_torch_from_flax_params
|
||||||
|
from .min_dalle import MinDalle
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
|
||||||
from .load_params import (
|
|
||||||
load_vqgan_torch_params,
|
class MinDalleTorch(MinDalle):
|
||||||
convert_dalle_bart_torch_from_flax_params
|
def __init__(self, is_mega: bool, sample_token_count: int = 256):
|
||||||
)
|
super().__init__(is_mega)
|
||||||
|
print("initializing MinDalleTorch")
|
||||||
|
|
||||||
|
print("loading encoder")
|
||||||
|
self.encoder = DalleBartEncoderTorch(
|
||||||
|
layer_count = self.config['encoder_layers'],
|
||||||
|
embed_count = self.config['d_model'],
|
||||||
|
attention_head_count = self.config['encoder_attention_heads'],
|
||||||
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
|
text_token_count = self.config['max_text_length'],
|
||||||
|
glu_embed_count = self.config['encoder_ffn_dim']
|
||||||
|
)
|
||||||
|
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
|
self.model_params.pop('encoder'),
|
||||||
|
layer_count=self.config['encoder_layers'],
|
||||||
|
is_encoder=True
|
||||||
|
)
|
||||||
|
self.encoder.load_state_dict(encoder_params, strict=False)
|
||||||
|
|
||||||
|
print("loading decoder")
|
||||||
|
self.decoder = DalleBartDecoderTorch(
|
||||||
|
image_vocab_size = self.config['image_vocab_size'],
|
||||||
|
image_token_count = self.config['image_length'],
|
||||||
|
sample_token_count = sample_token_count,
|
||||||
|
embed_count = self.config['d_model'],
|
||||||
|
attention_head_count = self.config['decoder_attention_heads'],
|
||||||
|
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||||
|
layer_count = self.config['decoder_layers'],
|
||||||
|
batch_count = 2,
|
||||||
|
start_token = self.config['decoder_start_token_id'],
|
||||||
|
is_verbose = True
|
||||||
|
)
|
||||||
|
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||||
|
self.model_params.pop('decoder'),
|
||||||
|
layer_count=self.config['decoder_layers'],
|
||||||
|
is_encoder=False
|
||||||
|
)
|
||||||
|
self.decoder.load_state_dict(decoder_params, strict=False)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.encoder = self.encoder.cuda()
|
||||||
|
self.decoder = self.decoder.cuda()
|
||||||
|
self.detokenizer = self.detokenizer.cuda()
|
||||||
|
|
||||||
|
|
||||||
def encode_torch(
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||||
text_tokens: LongTensor,
|
text_tokens = self.tokenize_text(text)
|
||||||
config: dict,
|
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||||
params: dict
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||||
) -> FloatTensor:
|
|
||||||
print("loading torch encoder")
|
|
||||||
encoder = DalleBartEncoderTorch(
|
|
||||||
layer_count = config['encoder_layers'],
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
attention_head_count = config['encoder_attention_heads'],
|
|
||||||
text_vocab_count = config['encoder_vocab_size'],
|
|
||||||
text_token_count = config['max_text_length'],
|
|
||||||
glu_embed_count = config['encoder_ffn_dim']
|
|
||||||
)
|
|
||||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
params.pop('encoder'),
|
|
||||||
layer_count=config['encoder_layers'],
|
|
||||||
is_encoder=True
|
|
||||||
)
|
|
||||||
encoder.load_state_dict(encoder_params, strict=False)
|
|
||||||
del encoder_params
|
|
||||||
if torch.cuda.is_available(): encoder = encoder.cuda()
|
|
||||||
|
|
||||||
print("encoding text tokens")
|
print("encoding text tokens")
|
||||||
encoder_state = encoder(text_tokens)
|
encoder_state = self.encoder.forward(text_tokens)
|
||||||
del encoder
|
|
||||||
return encoder_state
|
print("sampling image tokens")
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
image_tokens = self.decoder.forward(text_tokens, encoder_state)
|
||||||
|
return image_tokens
|
||||||
|
|
||||||
|
|
||||||
def decode_torch(
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
text_tokens: LongTensor,
|
image_tokens = self.generate_image_tokens(text, seed)
|
||||||
encoder_state: FloatTensor,
|
print("detokenizing image")
|
||||||
config: dict,
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
seed: int,
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
params: dict,
|
return image
|
||||||
image_token_count: int
|
|
||||||
) -> LongTensor:
|
|
||||||
print("loading torch decoder")
|
|
||||||
decoder = DalleBartDecoderTorch(
|
|
||||||
image_vocab_size = config['image_vocab_size'],
|
|
||||||
image_token_count = config['image_length'],
|
|
||||||
sample_token_count = image_token_count,
|
|
||||||
embed_count = config['d_model'],
|
|
||||||
attention_head_count = config['decoder_attention_heads'],
|
|
||||||
glu_embed_count = config['decoder_ffn_dim'],
|
|
||||||
layer_count = config['decoder_layers'],
|
|
||||||
batch_count = 2,
|
|
||||||
start_token = config['decoder_start_token_id'],
|
|
||||||
is_verbose = True
|
|
||||||
)
|
|
||||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
|
||||||
params.pop('decoder'),
|
|
||||||
layer_count=config['decoder_layers'],
|
|
||||||
is_encoder=False
|
|
||||||
)
|
|
||||||
decoder.load_state_dict(decoder_params, strict=False)
|
|
||||||
del decoder_params
|
|
||||||
if torch.cuda.is_available(): decoder = decoder.cuda()
|
|
||||||
|
|
||||||
print("sampling image tokens")
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
image_tokens = decoder.forward(text_tokens, encoder_state)
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens_torch(
|
|
||||||
text_tokens: numpy.ndarray,
|
|
||||||
seed: int,
|
|
||||||
config: dict,
|
|
||||||
params: dict,
|
|
||||||
image_token_count: int
|
|
||||||
) -> LongTensor:
|
|
||||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
|
||||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
|
||||||
encoder_state = encode_torch(
|
|
||||||
text_tokens,
|
|
||||||
config,
|
|
||||||
params
|
|
||||||
)
|
|
||||||
image_tokens = decode_torch(
|
|
||||||
text_tokens,
|
|
||||||
encoder_state,
|
|
||||||
config,
|
|
||||||
seed,
|
|
||||||
params,
|
|
||||||
image_token_count
|
|
||||||
)
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
|
|
||||||
print("detokenizing image")
|
|
||||||
model_path = './pretrained/vqgan'
|
|
||||||
params = load_vqgan_torch_params(model_path)
|
|
||||||
detokenizer = VQGanDetokenizer()
|
|
||||||
detokenizer.load_state_dict(params)
|
|
||||||
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
|
|
||||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
|
||||||
del detokenizer, params
|
|
||||||
return image.to('cpu').detach().numpy()
|
|
|
@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
values_state: jnp.ndarray,
|
||||||
|
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
|
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
self.final_ln = nn.LayerNorm(use_scale=False)
|
||||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
||||||
|
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
keys_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
values_state: jnp.ndarray,
|
||||||
|
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
decoder_state = self.lm_head(decoder_state)
|
decoder_state = self.lm_head(decoder_state)
|
||||||
return decoder_state, keys_state, values_state
|
return decoder_state, keys_state, values_state
|
||||||
|
|
||||||
def sample_image_tokens(self,
|
def sample_image_tokens(
|
||||||
|
self,
|
||||||
text_tokens: jnp.ndarray,
|
text_tokens: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
prng_key: jax.random.PRNGKey,
|
prng_key: jax.random.PRNGKey,
|
||||||
|
|
|
@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
keys_values: FloatTensor,
|
keys_values: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
|
@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayerTorch(nn.Module):
|
class DecoderLayerTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
head_count: int,
|
head_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
|
@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.token_indices = self.token_indices.cuda()
|
self.token_indices = self.token_indices.cuda()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
keys_values_state: FloatTensor,
|
||||||
|
@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
image_vocab_size: int,
|
image_vocab_size: int,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
sample_token_count: int,
|
sample_token_count: int,
|
||||||
|
@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.start_token = self.start_token.cuda()
|
self.start_token = self.start_token.cuda()
|
||||||
|
|
||||||
|
|
||||||
def decode_step(self,
|
def decode_step(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
keys_values_state: FloatTensor,
|
||||||
|
@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
return probs, keys_values
|
return probs, keys_values
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor
|
encoder_state: FloatTensor
|
||||||
) -> LongTensor:
|
) -> LongTensor:
|
||||||
|
|
|
@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
|
||||||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: jnp.ndarray,
|
keys: jnp.ndarray,
|
||||||
values: jnp.ndarray,
|
values: jnp.ndarray,
|
||||||
queries: jnp.ndarray,
|
queries: jnp.ndarray,
|
||||||
|
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
||||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self,
|
def __call__(
|
||||||
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
|
|
@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
|
||||||
self.one = torch.ones((1, 1))
|
self.one = torch.ones((1, 1))
|
||||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
|
self,
|
||||||
keys: FloatTensor,
|
keys: FloatTensor,
|
||||||
values: FloatTensor,
|
values: FloatTensor,
|
||||||
queries: FloatTensor,
|
queries: FloatTensor,
|
||||||
|
@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderTorch(nn.Module):
|
class DalleBartEncoderTorch(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
attention_head_count: int,
|
attention_head_count: int,
|
||||||
|
|
|
@ -8,7 +8,7 @@ class TextTokenizer:
|
||||||
pairs = [tuple(pair.split()) for pair in merges]
|
pairs = [tuple(pair.split()) for pair in merges]
|
||||||
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
||||||
|
|
||||||
def __call__(self, text: str) -> List[int]:
|
def tokenize(self, text: str) -> List[int]:
|
||||||
sep_token = self.token_from_subword['</s>']
|
sep_token = self.token_from_subword['</s>']
|
||||||
cls_token = self.token_from_subword['<s>']
|
cls_token = self.token_from_subword['<s>']
|
||||||
unk_token = self.token_from_subword['<unk>']
|
unk_token = self.token_from_subword['<unk>']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user