|
|
|
@ -3,6 +3,7 @@ import json |
|
|
|
|
import numpy |
|
|
|
|
from PIL import Image |
|
|
|
|
from typing import Tuple, List |
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
from min_dalle.load_params import load_dalle_bart_flax_params |
|
|
|
|
from min_dalle.text_tokenizer import TextTokenizer |
|
|
|
@ -53,25 +54,23 @@ def generate_image_from_text( |
|
|
|
|
text_tokens = tokenize_text(text, config, vocab, merges) |
|
|
|
|
params_dalle_bart = load_dalle_bart_flax_params(model_path) |
|
|
|
|
|
|
|
|
|
image_tokens = numpy.zeros(config['image_length']) |
|
|
|
|
if is_torch: |
|
|
|
|
image_tokens[:image_token_count] = generate_image_tokens_torch( |
|
|
|
|
image_tokens = generate_image_tokens_torch( |
|
|
|
|
text_tokens = text_tokens, |
|
|
|
|
seed = seed, |
|
|
|
|
config = config, |
|
|
|
|
params = params_dalle_bart, |
|
|
|
|
image_token_count = image_token_count |
|
|
|
|
) |
|
|
|
|
if image_token_count == config['image_length']: |
|
|
|
|
image = detokenize_torch(image_tokens) |
|
|
|
|
return Image.fromarray(image) |
|
|
|
|
else: |
|
|
|
|
image_tokens[...] = generate_image_tokens_flax( |
|
|
|
|
image_tokens = generate_image_tokens_flax( |
|
|
|
|
text_tokens = text_tokens, |
|
|
|
|
seed = seed, |
|
|
|
|
config = config, |
|
|
|
|
params = params_dalle_bart, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if image_token_count == config['image_length']: |
|
|
|
|
image = detokenize_torch(image_tokens) |
|
|
|
|
return Image.fromarray(image) |
|
|
|
|
else: |
|
|
|
|
return None |
|
|
|
|
image = detokenize_torch(torch.tensor(image_tokens)) |
|
|
|
|
return Image.fromarray(image) |