From a014dccc0596368ea9c10a115ed2536890865424 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Mon, 27 Jun 2022 16:49:42 -0400 Subject: [PATCH] fixed an issue with argument parser --- README.md | 11 ++-- image_from_text.py | 123 +++++++----------------------------- min_dalle/generate_image.py | 77 ++++++++++++++++++++++ 3 files changed, 106 insertions(+), 105 deletions(-) create mode 100644 min_dalle/generate_image.py diff --git a/README.md b/README.md index eef8bc5..5b4df29 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,26 @@ This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma ### Setup -Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are `flax` and `torch`. In the bash script, GitHub LFS is used to download the VQGan detokenizer and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. You can also download those files manually by visting the links in the bash script. +Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are `flax` and `torch`. In the bash script, GitHub LFS is used to download the VQGan detokenizer and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. Those files can also be downloaded manually by visting the links in the bash script. ### Run Here are some examples ``` -python3 image_from_text.py --text='alien life' --seed=7 +python image_from_text.py --seed=7 --text='alien life' ``` ![Alien](examples/alien.png) + ``` -python3 image_from_text.py --mega --seed=4 \ - --text='a comfy chair that looks like an avocado' +python image_from_text.py --mega --seed=4 --text='a comfy chair that looks like an avocado' ``` ![Avocado Armchair](examples/avocado_armchair.png) ``` -python3 image_from_text.py --mega --seed=100 \ - --text='court sketch of godzilla on trial' +python image_from_text.py --mega --seed=100 --text='court sketch of godzilla on trial' ``` ![Godzilla Trial](examples/godzilla_trial.png) \ No newline at end of file diff --git a/image_from_text.py b/image_from_text.py index f11d0fb..a56522d 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -1,66 +1,21 @@ import argparse import os -import json -import numpy from PIL import Image -from typing import Tuple, List -from min_dalle.load_params import load_dalle_bart_flax_params -from min_dalle.text_tokenizer import TextTokenizer -from min_dalle.min_dalle_flax import generate_image_tokens_flax -from min_dalle.min_dalle_torch import ( - generate_image_tokens_torch, - detokenize_torch -) +from min_dalle.generate_image import generate_image_from_text parser = argparse.ArgumentParser() -parser.add_argument( - '--text', - help='text to generate image from', - type=str -) -parser.add_argument( - '--seed', - help='random seed', - type=int, - default=0 -) -parser.add_argument( - '--mega', - help='use larger dalle mega model', - action=argparse.BooleanOptionalAction -) -parser.add_argument( - '--torch', - help='use torch transformers', - action=argparse.BooleanOptionalAction -) -parser.add_argument( - '--image_path', - help='path to save generated image', - type=str, - default='generated.png' -) -parser.add_argument( - '--image_token_count', - help='number of image tokens to generate (for debugging)', - type=int, - default=256 -) - - -def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]: - print("loading model") - for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']: - assert(os.path.exists(os.path.join(path, f))) - with open(path + '/config.json', 'r') as f: - config = json.load(f) - with open(path + '/vocab.json') as f: - vocab = json.load(f) - with open(path + '/merges.txt') as f: - merges = f.read().split("\n")[1:-1] - return config, vocab, merges +parser.add_argument('--mega', action='store_true') +parser.add_argument('--no-mega', dest='mega', action='store_false') +parser.set_defaults(mega=False) +parser.add_argument('--torch', action='store_true') +parser.add_argument('--no-torch', dest='torch', action='store_false') +parser.set_defaults(torch=False) +parser.add_argument('--text', type=str) +parser.add_argument('--seed', type=int, default=0) +parser.add_argument('--image_path', type=str, default='generated') +parser.add_argument('--image_token_count', type=int, default=256) # for debugging def ascii_from_image(image: Image.Image, size: int) -> str: @@ -71,59 +26,29 @@ def ascii_from_image(image: Image.Image, size: int) -> str: return '\n'.join(''.join(row) for row in chars) -def save_image(image: numpy.ndarray, path: str) -> Image.Image: +def save_image(image: Image.Image, path: str): if os.path.isdir(path): path = os.path.join(path, 'generated.png') elif not path.endswith('.png'): path += '.png' print("saving image to", path) - image: Image.Image = Image.fromarray(numpy.asarray(image)) image.save(path) return image -def tokenize_text( - text: str, - config: dict, - vocab: dict, - merges: List[str] -) -> numpy.ndarray: - print("tokenizing text") - tokens = TextTokenizer(vocab, merges)(text) - print("text tokens", tokens) - text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32) - text_tokens[0, :len(tokens)] = tokens - text_tokens[1, :2] = [tokens[0], tokens[-1]] - return text_tokens - - if __name__ == '__main__': args = parser.parse_args() - model_name = 'mega' if args.mega == True else 'mini' - model_path = './pretrained/dalle_bart_{}'.format(model_name) - config, vocab, merges = load_dalle_bart_metadata(model_path) - text_tokens = tokenize_text(args.text, config, vocab, merges) - params_dalle_bart = load_dalle_bart_flax_params(model_path) - - image_tokens = numpy.zeros(config['image_length']) - if args.torch == True: - image_tokens[:args.image_token_count] = generate_image_tokens_torch( - text_tokens = text_tokens, - seed = args.seed, - config = config, - params = params_dalle_bart, - image_token_count = args.image_token_count - ) - else: - image_tokens[...] = generate_image_tokens_flax( - text_tokens = text_tokens, - seed = args.seed, - config = config, - params = params_dalle_bart, - ) - - if args.image_token_count == config['image_length']: - image = detokenize_torch(image_tokens) - image = save_image(image, args.image_path) + print(args) + + image = generate_image_from_text( + text = args.text, + is_mega = args.mega, + is_torch = args.torch, + seed = args.seed, + image_token_count = args.image_token_count + ) + + if image != None: + save_image(image, args.image_path) print(ascii_from_image(image, size=128)) \ No newline at end of file diff --git a/min_dalle/generate_image.py b/min_dalle/generate_image.py new file mode 100644 index 0000000..401f9e2 --- /dev/null +++ b/min_dalle/generate_image.py @@ -0,0 +1,77 @@ +import os +import json +import numpy +from PIL import Image +from typing import Tuple, List + +from min_dalle.load_params import load_dalle_bart_flax_params +from min_dalle.text_tokenizer import TextTokenizer +from min_dalle.min_dalle_flax import generate_image_tokens_flax +from min_dalle.min_dalle_torch import ( + generate_image_tokens_torch, + detokenize_torch +) + +def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]: + print("parsing metadata from {}".format(path)) + for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']: + assert(os.path.exists(os.path.join(path, f))) + with open(path + '/config.json', 'r') as f: + config = json.load(f) + with open(path + '/vocab.json') as f: + vocab = json.load(f) + with open(path + '/merges.txt') as f: + merges = f.read().split("\n")[1:-1] + return config, vocab, merges + + +def tokenize_text( + text: str, + config: dict, + vocab: dict, + merges: List[str] +) -> numpy.ndarray: + print("tokenizing text") + tokens = TextTokenizer(vocab, merges)(text) + print("text tokens", tokens) + text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32) + text_tokens[0, :len(tokens)] = tokens + text_tokens[1, :2] = [tokens[0], tokens[-1]] + return text_tokens + + +def generate_image_from_text( + text: str, + is_mega: bool = False, + is_torch: bool = False, + seed: int = 0, + image_token_count: int = 256 +) -> Image.Image: + model_name = 'mega' if is_mega else 'mini' + model_path = './pretrained/dalle_bart_{}'.format(model_name) + config, vocab, merges = load_dalle_bart_metadata(model_path) + text_tokens = tokenize_text(text, config, vocab, merges) + params_dalle_bart = load_dalle_bart_flax_params(model_path) + + image_tokens = numpy.zeros(config['image_length']) + if is_torch: + image_tokens[:image_token_count] = generate_image_tokens_torch( + text_tokens = text_tokens, + seed = seed, + config = config, + params = params_dalle_bart, + image_token_count = image_token_count + ) + else: + image_tokens[...] = generate_image_tokens_flax( + text_tokens = text_tokens, + seed = seed, + config = config, + params = params_dalle_bart, + ) + + if image_token_count == config['image_length']: + image = detokenize_torch(image_tokens) + return Image.fromarray(image) + else: + return None \ No newline at end of file