fixed an issue with argument parser

Brett Kuprel 2022-06-27 16:49:42 -04:00
parent e7001f063c
commit a014dccc05
3 changed files with 105 additions and 104 deletions

README.md

@@ -4,27 +4,26 @@ This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma
 ### Setup
-Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are `flax` and `torch`. In the bash script, GitHub LFS is used to download the VQGan detokenizer and the Weights & Biases Python package is used to download the DALL·E Mini and DALL·E Mega transformer models. You can also download those files manually by visiting the links in the bash script.
+Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are `flax` and `torch`. In the bash script, GitHub LFS is used to download the VQGan detokenizer and the Weights & Biases Python package is used to download the DALL·E Mini and DALL·E Mega transformer models. Those files can also be downloaded manually by visiting the links in the bash script.
 ### Run
 Here are some examples
 ```
-python3 image_from_text.py --text='alien life' --seed=7
+python image_from_text.py --seed=7 --text='alien life'
 ```
 ![Alien](examples/alien.png)
 ```
-python3 image_from_text.py --mega --seed=4 \
-    --text='a comfy chair that looks like an avocado'
+python image_from_text.py --mega --seed=4 --text='a comfy chair that looks like an avocado'
 ```
 ![Avocado Armchair](examples/avocado_armchair.png)
 ```
-python3 image_from_text.py --mega --seed=100 \
-    --text='court sketch of godzilla on trial'
+python image_from_text.py --mega --seed=100 --text='court sketch of godzilla on trial'
 ```
 ![Godzilla Trial](examples/godzilla_trial.png)

image_from_text.py

@@ -1,66 +1,21 @@
 import argparse
 import os
-import json
-import numpy
 from PIL import Image
-from typing import Tuple, List
-from min_dalle.load_params import load_dalle_bart_flax_params
-from min_dalle.text_tokenizer import TextTokenizer
-from min_dalle.min_dalle_flax import generate_image_tokens_flax
-from min_dalle.min_dalle_torch import (
-    generate_image_tokens_torch,
-    detokenize_torch
-)
+from min_dalle.generate_image import generate_image_from_text
 
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    '--text',
-    help='text to generate image from',
-    type=str
-)
-parser.add_argument(
-    '--seed',
-    help='random seed',
-    type=int,
-    default=0
-)
-parser.add_argument(
-    '--mega',
-    help='use larger dalle mega model',
-    action=argparse.BooleanOptionalAction
-)
-parser.add_argument(
-    '--torch',
-    help='use torch transformers',
-    action=argparse.BooleanOptionalAction
-)
-parser.add_argument(
-    '--image_path',
-    help='path to save generated image',
-    type=str,
-    default='generated.png'
-)
-parser.add_argument(
-    '--image_token_count',
-    help='number of image tokens to generate (for debugging)',
-    type=int,
-    default=256
-)
+parser.add_argument('--mega', action='store_true')
+parser.add_argument('--no-mega', dest='mega', action='store_false')
+parser.set_defaults(mega=False)
+parser.add_argument('--torch', action='store_true')
+parser.add_argument('--no-torch', dest='torch', action='store_false')
+parser.set_defaults(torch=False)
+parser.add_argument('--text', type=str)
+parser.add_argument('--seed', type=int, default=0)
+parser.add_argument('--image_path', type=str, default='generated')
+parser.add_argument('--image_token_count', type=int, default=256) # for debugging
 
-def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
-    print("loading model")
-    for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
-        assert(os.path.exists(os.path.join(path, f)))
-    with open(path + '/config.json', 'r') as f:
-        config = json.load(f)
-    with open(path + '/vocab.json') as f:
-        vocab = json.load(f)
-    with open(path + '/merges.txt') as f:
-        merges = f.read().split("\n")[1:-1]
-    return config, vocab, merges
-
 def ascii_from_image(image: Image.Image, size: int) -> str:
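A note on the parser change above: the removed `action=argparse.BooleanOptionalAction` exists only on Python 3.9+, which is presumably the argument-parser issue the commit title refers to; the paired `store_true`/`store_false` flags behave the same way on older interpreters. A minimal sketch of the pattern in isolation:

```
import argparse

# --mega stores True and --no-mega stores False into the same 'mega' dest;
# set_defaults supplies the value used when neither flag is passed.
parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true')
parser.add_argument('--no-mega', dest='mega', action='store_false')
parser.set_defaults(mega=False)

print(parser.parse_args(['--mega']).mega)     # True
print(parser.parse_args(['--no-mega']).mega)  # False
print(parser.parse_args([]).mega)             # False
```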
@@ -71,59 +26,29 @@ def ascii_from_image(image: Image.Image, size: int) -> str:
     return '\n'.join(''.join(row) for row in chars)
 
-def save_image(image: numpy.ndarray, path: str) -> Image.Image:
+def save_image(image: Image.Image, path: str):
     if os.path.isdir(path):
         path = os.path.join(path, 'generated.png')
     elif not path.endswith('.png'):
         path += '.png'
     print("saving image to", path)
-    image: Image.Image = Image.fromarray(numpy.asarray(image))
     image.save(path)
     return image
 
-def tokenize_text(
-    text: str,
-    config: dict,
-    vocab: dict,
-    merges: List[str]
-) -> numpy.ndarray:
-    print("tokenizing text")
-    tokens = TextTokenizer(vocab, merges)(text)
-    print("text tokens", tokens)
-    text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
-    text_tokens[0, :len(tokens)] = tokens
-    text_tokens[1, :2] = [tokens[0], tokens[-1]]
-    return text_tokens
-
 if __name__ == '__main__':
     args = parser.parse_args()
-    model_name = 'mega' if args.mega == True else 'mini'
-    model_path = './pretrained/dalle_bart_{}'.format(model_name)
-    config, vocab, merges = load_dalle_bart_metadata(model_path)
-    text_tokens = tokenize_text(args.text, config, vocab, merges)
-    params_dalle_bart = load_dalle_bart_flax_params(model_path)
-    image_tokens = numpy.zeros(config['image_length'])
-    if args.torch == True:
-        image_tokens[:args.image_token_count] = generate_image_tokens_torch(
-            text_tokens = text_tokens,
-            seed = args.seed,
-            config = config,
-            params = params_dalle_bart,
-            image_token_count = args.image_token_count
-        )
-    else:
-        image_tokens[...] = generate_image_tokens_flax(
-            text_tokens = text_tokens,
-            seed = args.seed,
-            config = config,
-            params = params_dalle_bart,
-        )
-    if args.image_token_count == config['image_length']:
-        image = detokenize_torch(image_tokens)
-        image = save_image(image, args.image_path)
+    print(args)
+    image = generate_image_from_text(
+        text = args.text,
+        is_mega = args.mega,
+        is_torch = args.torch,
+        seed = args.seed,
+        image_token_count = args.image_token_count
+    )
+    if image != None:
+        save_image(image, args.image_path)
         print(ascii_from_image(image, size=128))
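The `if image != None:` guard pairs with the return value of `generate_image_from_text` in the new module below: when `image_token_count` is smaller than the model's `image_length`, nothing is detokenized and `None` comes back, so a truncated debugging run just prints the parsed arguments and exits without saving. A hypothetical debugging invocation (the token count here is illustrative):

```
python image_from_text.py --torch --image_token_count=32 --text='alien life'
```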

min_dalle/generate_image.py

@@ -0,0 +1,77 @@
+import os
+import json
+import numpy
+from PIL import Image
+from typing import Tuple, List
+from min_dalle.load_params import load_dalle_bart_flax_params
+from min_dalle.text_tokenizer import TextTokenizer
+from min_dalle.min_dalle_flax import generate_image_tokens_flax
+from min_dalle.min_dalle_torch import (
+    generate_image_tokens_torch,
+    detokenize_torch
+)
+
+def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
+    print("parsing metadata from {}".format(path))
+    for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
+        assert(os.path.exists(os.path.join(path, f)))
+    with open(path + '/config.json', 'r') as f:
+        config = json.load(f)
+    with open(path + '/vocab.json') as f:
+        vocab = json.load(f)
+    with open(path + '/merges.txt') as f:
+        merges = f.read().split("\n")[1:-1]
+    return config, vocab, merges
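For reference, the directory layout these asserts expect, as downloaded by `setup.sh` (with `dalle_bart_mega` alongside `dalle_bart_mini` when the mega model is fetched):

```
pretrained/dalle_bart_mini/
├── config.json
├── flax_model.msgpack
├── vocab.json
└── merges.txt
```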
+def tokenize_text(
+    text: str,
+    config: dict,
+    vocab: dict,
+    merges: List[str]
+) -> numpy.ndarray:
+    print("tokenizing text")
+    tokens = TextTokenizer(vocab, merges)(text)
+    print("text tokens", tokens)
+    text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
+    text_tokens[0, :len(tokens)] = tokens
+    text_tokens[1, :2] = [tokens[0], tokens[-1]]
+    return text_tokens
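For intuition, a worked example of the array `tokenize_text` builds, with toy token values standing in for the real vocabulary: row 0 holds the full prompt, row 1 only the start and end tokens, and the pad value 1 fills the rest.

```
import numpy

tokens = [0, 4123, 568, 2]  # toy stand-in for TextTokenizer(vocab, merges)(text)
max_text_length = 8         # toy stand-in for config['max_text_length']

text_tokens = numpy.ones((2, max_text_length), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
print(text_tokens)
# [[   0 4123  568    2    1    1    1    1]
#  [   0    2    1    1    1    1    1    1]]
```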
+def generate_image_from_text(
+    text: str,
+    is_mega: bool = False,
+    is_torch: bool = False,
+    seed: int = 0,
+    image_token_count: int = 256
+) -> Image.Image:
+    model_name = 'mega' if is_mega else 'mini'
+    model_path = './pretrained/dalle_bart_{}'.format(model_name)
+    config, vocab, merges = load_dalle_bart_metadata(model_path)
+    text_tokens = tokenize_text(text, config, vocab, merges)
+    params_dalle_bart = load_dalle_bart_flax_params(model_path)
+    image_tokens = numpy.zeros(config['image_length'])
+    if is_torch:
+        image_tokens[:image_token_count] = generate_image_tokens_torch(
+            text_tokens = text_tokens,
+            seed = seed,
+            config = config,
+            params = params_dalle_bart,
+            image_token_count = image_token_count
+        )
+    else:
+        image_tokens[...] = generate_image_tokens_flax(
+            text_tokens = text_tokens,
+            seed = seed,
+            config = config,
+            params = params_dalle_bart,
+        )
+    if image_token_count == config['image_length']:
+        image = detokenize_torch(image_tokens)
+        return Image.fromarray(image)
+    else:
+        return None
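A minimal sketch of calling the new module directly instead of through the CLI, assuming the pretrained weights already sit under `./pretrained` relative to the working directory:

```
from min_dalle.generate_image import generate_image_from_text

image = generate_image_from_text(text='alien life', is_mega=False, seed=7)
if image is not None:
    image.save('alien.png')
```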