simplified

This commit is contained in:
Brett Kuprel 2022-06-27 15:46:04 -04:00
parent 18e6a9852f
commit e7001f063c
8 changed files with 182 additions and 344 deletions

View File

@ -1,4 +1,4 @@
# min DALL·E # min(DALL·E)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in both Flax and PyTorch This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in both Flax and PyTorch
@ -11,21 +11,19 @@ Run `sh setup.sh` to install dependencies and download pretrained models. The o
Here are some examples Here are some examples
``` ```
python3 image_from_text_flax.py \ python3 image_from_text.py --text='alien life' --seed=7
--dalle_bart_path='./pretrained/dalle_bart_mega' \ ```
--vqgan_path='./pretrained/vqgan' \ ![Alien](examples/alien.png)
--image_path='./generated/avacado_armchair_flax.png' \
--seed=4 \ ```
python3 image_from_text.py --mega --seed=4 \
--text='a comfy chair that looks like an avocado' --text='a comfy chair that looks like an avocado'
``` ```
![Avocado Armchair](examples/avocado_armchair.png) ![Avocado Armchair](examples/avocado_armchair.png)
``` ```
python3 image_from_text_flax.py \ python3 image_from_text.py --mega --seed=100 \
--dalle_path='./pretrained/dalle-mega' \
--seed=100 \
--image_path='./generated/godzilla_trial.png' \
--text='court sketch of godzilla on trial' --text='court sketch of godzilla on trial'
``` ```

BIN
examples/alien.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

129
image_from_text.py Normal file
View File

@ -0,0 +1,129 @@
import argparse
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
generate_image_tokens_torch,
detokenize_torch
)
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
help='text to generate image from',
type=str
)
parser.add_argument(
'--seed',
help='random seed',
type=int,
default=0
)
parser.add_argument(
'--mega',
help='use larger dalle mega model',
action=argparse.BooleanOptionalAction
)
parser.add_argument(
'--torch',
help='use torch transformers',
action=argparse.BooleanOptionalAction
)
parser.add_argument(
'--image_path',
help='path to save generated image',
type=str,
default='generated.png'
)
parser.add_argument(
'--image_token_count',
help='number of image tokens to generate (for debugging)',
type=int,
default=256
)
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("loading model")
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def ascii_from_image(image: Image.Image, size: int) -> str:
rgb_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata()
chars = list('.,;/IOX')
chars = [chars[i * len(chars) // 256] for i in rgb_pixels]
chars = [chars[i * size: (i + 1) * size] for i in range(size // 2)]
return '\n'.join(''.join(row) for row in chars)
def save_image(image: numpy.ndarray, path: str) -> Image.Image:
if os.path.isdir(path):
path = os.path.join(path, 'generated.png')
elif not path.endswith('.png'):
path += '.png'
print("saving image to", path)
image: Image.Image = Image.fromarray(numpy.asarray(image))
image.save(path)
return image
def tokenize_text(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
if __name__ == '__main__':
args = parser.parse_args()
model_name = 'mega' if args.mega == True else 'mini'
model_path = './pretrained/dalle_bart_{}'.format(model_name)
config, vocab, merges = load_dalle_bart_metadata(model_path)
text_tokens = tokenize_text(args.text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
image_tokens = numpy.zeros(config['image_length'])
if args.torch == True:
image_tokens[:args.image_token_count] = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = args.seed,
config = config,
params = params_dalle_bart,
image_token_count = args.image_token_count
)
else:
image_tokens[...] = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = args.seed,
config = config,
params = params_dalle_bart,
)
if args.image_token_count == config['image_length']:
image = detokenize_torch(image_tokens)
image = save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

View File

@ -1,126 +0,0 @@
import jax
from jax import numpy as jnp
import numpy
import argparse
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.image_from_text import (
load_dalle_bart_metadata,
tokenize,
detokenize_torch,
save_image,
ascii_from_image
)
from min_dalle.models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from min_dalle.models.dalle_bart_decoder_flax import DalleBartDecoderFlax
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
help='text to generate image from',
type=str
)
parser.add_argument(
'--seed',
help='random seed',
type=int,
default=0
)
parser.add_argument(
'--image_path',
help='generated image path',
type=str,
default='generated.png'
)
parser.add_argument(
'--dalle_bart_path',
help='pretraied dalle bart path',
type=str,
default='./pretrained/dalle_bart_mini'
)
parser.add_argument(
'--vqgan_path',
help='pretraied vqgan path',
type=str,
default='./pretrained/vqgan'
)
def encode_flax(
text_tokens: numpy.ndarray,
config: dict,
params: dict
) -> jnp.ndarray:
print("loading flax encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['encoder_ffn_dim'],
text_token_count = config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'],
layer_count = config['encoder_layers']
).bind({'params': params.pop('encoder')})
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
)
print("sampling image tokens")
image_tokens = decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
params.pop('decoder')
)
del decoder
return image_tokens
def generate_image_tokens_flax(
text: str,
seed: int,
dalle_bart_path: str
) -> numpy.ndarray:
config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path)
text_tokens = tokenize(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path)
encoder_state = encode_flax(text_tokens, config, params_dalle_bart)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config, seed,
params_dalle_bart
)
return numpy.array(image_tokens)
if __name__ == '__main__':
args = parser.parse_args()
image_tokens = generate_image_tokens_flax(
args.text,
args.seed,
args.dalle_bart_path
)
print("image tokens", list(image_tokens))
image = detokenize_torch(image_tokens, args.vqgan_path)
image = save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

View File

@ -1,70 +0,0 @@
import os
import json
import numpy
import torch
from PIL import Image
from typing import Tuple, List
from .text_tokenizer import TextTokenizer
from .models.vqgan_detokenizer import VQGanDetokenizer
from .load_params import load_vqgan_torch_params
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("loading model")
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def ascii_from_image(image: Image.Image, size: int) -> str:
rgb_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata()
chars = list('.,;/IOX')
chars = [chars[i * len(chars) // 256] for i in rgb_pixels]
chars = [chars[i * size: (i + 1) * size] for i in range(size // 2)]
return '\n'.join(''.join(row) for row in chars)
def save_image(image: numpy.ndarray, path: str) -> Image.Image:
if os.path.isdir(path):
path = os.path.join(path, 'generated.png')
elif not path.endswith('.png'):
path += '.png'
print("saving image to", path)
image: Image.Image = Image.fromarray(numpy.asarray(image))
image.save(path)
return image
def tokenize(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
def detokenize_torch(
image_tokens: numpy.ndarray,
model_path: str
) -> numpy.ndarray:
print("detokenizing image")
params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params)
image_tokens = torch.tensor(image_tokens).to(torch.long)
image = detokenizer.forward(image_tokens).to(torch.uint8)
return image.detach().numpy()

View File

@ -1,50 +1,9 @@
import jax import jax
from jax import numpy as jnp from jax import numpy as jnp
import numpy import numpy
import argparse
from load_params import load_dalle_bart_flax_params from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from image_from_text import ( from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
load_dalle_bart_metadata,
tokenize,
detokenize_torch,
save_image,
ascii_from_image
)
from models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from models.dalle_bart_decoder_flax import DalleBartDecoderFlax
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
help='text to generate image from',
type=str
)
parser.add_argument(
'--seed',
help='random seed',
type=int,
default=0
)
parser.add_argument(
'--image_path',
help='generated image path',
type=str,
default='generated.png'
)
parser.add_argument(
'--dalle_bart_path',
help='pretraied dalle bart path',
type=str,
default='./pretrained/dalle_bart_mini'
)
parser.add_argument(
'--vqgan_path',
help='pretraied vqgan path',
type=str,
default='./pretrained/vqgan'
)
def encode_flax( def encode_flax(
@ -67,6 +26,7 @@ def encode_flax(
del encoder del encoder
return encoder_state return encoder_state
def decode_flax( def decode_flax(
text_tokens: jnp.ndarray, text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
@ -95,32 +55,25 @@ def decode_flax(
del decoder del decoder
return image_tokens return image_tokens
def generate_image_tokens_flax( def generate_image_tokens_flax(
text: str, text_tokens: numpy.ndarray,
seed: int, seed: int,
dalle_bart_path: str config: dict,
params: dict
) -> numpy.ndarray: ) -> numpy.ndarray:
config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path) encoder_state = encode_flax(
text_tokens = tokenize(text, config, vocab, merges) text_tokens,
params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path) config,
encoder_state = encode_flax(text_tokens, config, params_dalle_bart) params
)
image_tokens = decode_flax( image_tokens = decode_flax(
text_tokens, text_tokens,
encoder_state, encoder_state,
config, seed, config,
params_dalle_bart seed,
) params
return numpy.array(image_tokens)
if __name__ == '__main__':
args = parser.parse_args()
image_tokens = generate_image_tokens_flax(
args.text,
args.seed,
args.dalle_bart_path
) )
image_tokens = numpy.array(image_tokens)
print("image tokens", list(image_tokens)) print("image tokens", list(image_tokens))
image = detokenize_torch(image_tokens, args.vqgan_path) return image_tokens
image = save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

View File

@ -1,61 +1,17 @@
import numpy import numpy
import torch import torch
from torch import Tensor from torch import Tensor
import argparse
from typing import Dict from typing import Dict
from min_dalle.image_from_text import ( from .models.vqgan_detokenizer import VQGanDetokenizer
load_dalle_bart_metadata, from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
tokenize, from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
detokenize_torch,
save_image,
ascii_from_image
)
from min_dalle.models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from min_dalle.models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from min_dalle.load_params import ( from .load_params import (
load_dalle_bart_flax_params, load_vqgan_torch_params,
convert_dalle_bart_torch_from_flax_params convert_dalle_bart_torch_from_flax_params
) )
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
help='text to generate image from',
type=str
)
parser.add_argument(
'--seed',
help='random seed',
type=int,
default=0
)
parser.add_argument(
'--image_token_count',
help='image tokens to sample',
type=int,
default=256
)
parser.add_argument(
'--image_path',
help='generated image path',
type=str,
default='generated.png'
)
parser.add_argument(
'--dalle_bart_path',
help='pretraied dalle bart path',
type=str,
default='./pretrained/dalle_bart_mini'
)
parser.add_argument(
'--vqgan_path',
help='pretraied vqgan path',
type=str,
default='./pretrained/vqgan'
)
def encode_torch( def encode_torch(
text_tokens: numpy.ndarray, text_tokens: numpy.ndarray,
@ -123,37 +79,35 @@ def decode_torch(
def generate_image_tokens_torch( def generate_image_tokens_torch(
text: str, text_tokens: numpy.ndarray,
seed: int, seed: int,
image_token_count: int, config: dict,
dalle_bart_path: str params: dict,
image_token_count: int
) -> numpy.ndarray: ) -> numpy.ndarray:
config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path) encoder_state = encode_torch(
text_tokens = tokenize(text, config, vocab, merges) text_tokens,
params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path) config,
encoder_state = encode_torch(text_tokens, config, params_dalle_bart) params
)
image_tokens = decode_torch( image_tokens = decode_torch(
text_tokens, text_tokens,
encoder_state, encoder_state,
config, seed, params_dalle_bart, config,
seed,
params,
image_token_count image_token_count
) )
return image_tokens.detach().numpy() return image_tokens.detach().numpy()
if __name__ == '__main__': def detokenize_torch(image_tokens: numpy.ndarray) -> numpy.ndarray:
args = parser.parse_args() print("detokenizing image")
image_tokens = generate_image_tokens_torch( model_path = './pretrained/vqgan'
args.text, params = load_vqgan_torch_params(model_path)
args.seed, detokenizer = VQGanDetokenizer()
args.image_token_count, detokenizer.load_state_dict(params)
args.dalle_bart_path image_tokens = torch.tensor(image_tokens).to(torch.long)
) image = detokenizer.forward(image_tokens).to(torch.uint8)
if args.image_token_count < 256: return image.detach().numpy()
print("image tokens", list(image_tokens, ))
else:
image = detokenize_torch(image_tokens, args.vqgan_path)
image = save_image(image, args.image_path)
print(ascii_from_image(image, size=128))