fixed relative imports

This commit is contained in:
Brett Kuprel 2022-06-27 12:43:47 -04:00
parent 9cf9d46dd7
commit 018414a5c3
12 changed files with 137 additions and 11 deletions

View File

@ -3,16 +3,16 @@ from jax import numpy as jnp
import numpy import numpy
import argparse import argparse
from load_params import load_dalle_bart_flax_params from min_dalle.load_params import load_dalle_bart_flax_params
from image_from_text import ( from min_dalle.image_from_text import (
load_dalle_bart_metadata, load_dalle_bart_metadata,
tokenize, tokenize,
detokenize_torch, detokenize_torch,
save_image, save_image,
ascii_from_image ascii_from_image
) )
from models.dalle_bart_encoder_flax import DalleBartEncoderFlax from min_dalle.models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from models.dalle_bart_decoder_flax import DalleBartDecoderFlax from min_dalle.models.dalle_bart_decoder_flax import DalleBartDecoderFlax
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -4,17 +4,17 @@ from torch import Tensor
import argparse import argparse
from typing import Dict from typing import Dict
from image_from_text import ( from min_dalle.image_from_text import (
load_dalle_bart_metadata, load_dalle_bart_metadata,
tokenize, tokenize,
detokenize_torch, detokenize_torch,
save_image, save_image,
ascii_from_image ascii_from_image
) )
from models.dalle_bart_encoder_torch import DalleBartEncoderTorch from min_dalle.models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from models.dalle_bart_decoder_torch import DalleBartDecoderTorch from min_dalle.models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from load_params import ( from min_dalle.load_params import (
load_dalle_bart_flax_params, load_dalle_bart_flax_params,
convert_dalle_bart_torch_from_flax_params convert_dalle_bart_torch_from_flax_params
) )

View File

@ -5,9 +5,9 @@ import torch
from PIL import Image from PIL import Image
from typing import Tuple, List from typing import Tuple, List
from text_tokenizer import TextTokenizer from .text_tokenizer import TextTokenizer
from models.vqgan_detokenizer import VQGanDetokenizer from .models.vqgan_detokenizer import VQGanDetokenizer
from load_params import load_vqgan_torch_params from .load_params import load_vqgan_torch_params
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]: def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:

View File

@ -0,0 +1,126 @@
import jax
from jax import numpy as jnp
import numpy
import argparse
from load_params import load_dalle_bart_flax_params
from image_from_text import (
load_dalle_bart_metadata,
tokenize,
detokenize_torch,
save_image,
ascii_from_image
)
from models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from models.dalle_bart_decoder_flax import DalleBartDecoderFlax
parser = argparse.ArgumentParser()
parser.add_argument(
'--text',
help='text to generate image from',
type=str
)
parser.add_argument(
'--seed',
help='random seed',
type=int,
default=0
)
parser.add_argument(
'--image_path',
help='generated image path',
type=str,
default='generated.png'
)
parser.add_argument(
'--dalle_bart_path',
help='pretraied dalle bart path',
type=str,
default='./pretrained/dalle_bart_mini'
)
parser.add_argument(
'--vqgan_path',
help='pretraied vqgan path',
type=str,
default='./pretrained/vqgan'
)
def encode_flax(
text_tokens: numpy.ndarray,
config: dict,
params: dict
) -> jnp.ndarray:
print("loading flax encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['encoder_ffn_dim'],
text_token_count = config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'],
layer_count = config['encoder_layers']
).bind({'params': params.pop('encoder')})
print("encoding text tokens")
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
)
print("sampling image tokens")
image_tokens = decoder.sample_image_tokens(
text_tokens,
encoder_state,
jax.random.PRNGKey(seed),
params.pop('decoder')
)
del decoder
return image_tokens
def generate_image_tokens_flax(
text: str,
seed: int,
dalle_bart_path: str
) -> numpy.ndarray:
config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path)
text_tokens = tokenize(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path)
encoder_state = encode_flax(text_tokens, config, params_dalle_bart)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config, seed,
params_dalle_bart
)
return numpy.array(image_tokens)
if __name__ == '__main__':
args = parser.parse_args()
image_tokens = generate_image_tokens_flax(
args.text,
args.seed,
args.dalle_bart_path
)
print("image tokens", list(image_tokens))
image = detokenize_torch(image_tokens, args.vqgan_path)
image = save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

View File