remove config.json dependency, default to torch in image_from_text.py

main
Brett Kuprel 2 years ago
parent 4404e70764
commit 85f5866eff
  1. 2
      .gitignore
  2. 6
      README.md
  3. 8
      image_from_text.py
  4. 25
      min_dalle.ipynb
  5. 6
      min_dalle/min_dalle_base.py
  6. 27
      min_dalle/min_dalle_flax.py
  7. 29
      min_dalle/min_dalle_torch.py
  8. 1
      min_dalle/models/dalle_bart_decoder_flax.py
  9. 10
      min_dalle/models/dalle_bart_decoder_torch.py
  10. 2
      setup_torch.sh

2
.gitignore vendored

@@ -2,7 +2,7 @@
**/.cache/ **/.cache/
**/*.pkl **/*.pkl
**/.DS* **/.DS*
*.pt **/*.pt
*.mlpackage *.mlpackage
**/*.ckpt **/*.ckpt
.vscode .vscode

6
README.md vendored

@@ -22,19 +22,19 @@ Use the python script `image_from_text.py` to generate images from the command l
### Examples ### Examples
``` ```
python image_from_text.py --text='artificial intelligence' --torch python image_from_text.py --text='artificial intelligence' --seed=7
``` ```
![Alien](examples/artificial_intelligence.png) ![Alien](examples/artificial_intelligence.png)
``` ```
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10 python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=10
``` ```
![Avocado Armchair](examples/avocado_armchair.png) ![Avocado Armchair](examples/avocado_armchair.png)
``` ```
python image_from_text.py --text='court sketch of godzilla on trial' --torch --mega --seed=40 python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=40
``` ```
![Godzilla Trial](examples/godzilla_trial.png) ![Godzilla Trial](examples/godzilla_trial.png)

@@ -11,9 +11,9 @@ parser.add_argument('--no-mega', dest='mega', action='store_false')
parser.set_defaults(mega=False) parser.set_defaults(mega=False)
parser.add_argument('--torch', action='store_true') parser.add_argument('--torch', action='store_true')
parser.add_argument('--no-torch', dest='torch', action='store_false') parser.add_argument('--no-torch', dest='torch', action='store_false')
parser.set_defaults(torch=False) parser.set_defaults(torch=True)
parser.add_argument('--text', type=str, default='alien life') parser.add_argument('--text', type=str, default='cat')
parser.add_argument('--seed', type=int, default=7) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--image_path', type=str, default='generated') parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--token_count', type=int, default=256) # for debugging parser.add_argument('--token_count', type=int, default=256) # for debugging
@@ -48,7 +48,7 @@ def generate_image(
if is_torch: if is_torch:
image_generator = MinDalleTorch(is_mega, is_reusable, token_count) image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
if token_count < image_generator.config['image_length']: if token_count < 256:
image_tokens = image_generator.generate_image_tokens(text, seed) image_tokens = image_generator.generate_image_tokens(text, seed)
print('image tokens', list(image_tokens.to('cpu').detach().numpy())) print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return return

25
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

@@ -11,12 +11,9 @@ class MinDalleBase:
self.model_path = os.path.join('pretrained', model_name) self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(self.model_path)) print("reading files from {}".format(self.model_path))
config_path = os.path.join(self.model_path, 'config.json')
vocab_path = os.path.join(self.model_path, 'vocab.json') vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt') merges_path = os.path.join(self.model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
with open(vocab_path, 'r', encoding='utf8') as f: with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f) vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f: with open(merges_path, 'r', encoding='utf8') as f:
@@ -29,8 +26,7 @@ class MinDalleBase:
print("tokenizing text") print("tokenizing text")
tokens = self.tokenizer.tokenize(text) tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens) print("text tokens", tokens)
text_token_count = self.config['max_text_length'] text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]] text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens text_tokens[1, :len(tokens)] = tokens
return text_tokens return text_tokens

@@ -26,26 +26,25 @@ class MinDalleFlax(MinDalleBase):
def init_encoder(self): def init_encoder(self):
print("initializing DalleBartEncoderFlax") print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'], attention_head_count = 32 if self.is_mega else 16,
embed_count = self.config['d_model'], embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = self.config['encoder_ffn_dim'], glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = self.config['max_text_length'], text_token_count = 64,
text_vocab_count = self.config['encoder_vocab_size'], text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = self.config['encoder_layers'] layer_count = 24 if self.is_mega else 12
).bind({'params': self.model_params.pop('encoder')}) ).bind({'params': self.model_params.pop('encoder')})
def init_decoder(self): def init_decoder(self):
print("initializing DalleBartDecoderFlax") print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax( self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'], image_token_count = 256,
text_token_count = self.config['max_text_length'], image_vocab_count = 16415 if self.is_mega else 16384,
image_vocab_count = self.config['image_vocab_size'], attention_head_count = 32 if self.is_mega else 16,
attention_head_count = self.config['decoder_attention_heads'], embed_count = 2048 if self.is_mega else 1024,
embed_count = self.config['d_model'], glu_embed_count = 4096 if self.is_mega else 2730,
glu_embed_count = self.config['decoder_ffn_dim'], layer_count = 24 if self.is_mega else 12,
layer_count = self.config['decoder_layers'], start_token = 16415 if self.is_mega else 16384
start_token = self.config['decoder_start_token_id']
) )

@@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
def init_encoder(self): def init_encoder(self):
print("initializing DalleBartEncoderTorch") print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch( self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'], attention_head_count = 32 if self.is_mega else 16,
embed_count = self.config['d_model'], embed_count = 2048 if self.is_mega else 1024,
attention_head_count = self.config['encoder_attention_heads'], glu_embed_count = 4096 if self.is_mega else 2730,
text_vocab_count = self.config['encoder_vocab_size'], text_token_count = 64,
text_token_count = self.config['max_text_length'], text_vocab_count = 50272 if self.is_mega else 50264,
glu_embed_count = self.config['encoder_ffn_dim'] layer_count = 24 if self.is_mega else 12
) )
params = torch.load(self.encoder_params_path) params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False) self.encoder.load_state_dict(params, strict=False)
@@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
def init_decoder(self): def init_decoder(self):
print("initializing DalleBartDecoderTorch") print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch( self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = self.token_count, sample_token_count = self.token_count,
embed_count = self.config['d_model'], image_token_count = 256,
attention_head_count = self.config['decoder_attention_heads'], image_vocab_count = 16415 if self.is_mega else 16384,
glu_embed_count = self.config['decoder_ffn_dim'], attention_head_count = 32 if self.is_mega else 16,
layer_count = self.config['decoder_layers'], embed_count = 2048 if self.is_mega else 1024,
batch_count = 2, glu_embed_count = 4096 if self.is_mega else 2730,
start_token = self.config['decoder_start_token_id'], layer_count = 24 if self.is_mega else 12,
is_verbose = True start_token = 16415 if self.is_mega else 16384,
batch_count = 2
) )
params = torch.load(self.decoder_params_path) params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False) self.decoder.load_state_dict(params, strict=False)

@@ -130,7 +130,6 @@ def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
class DalleBartDecoderFlax(nn.Module): class DalleBartDecoderFlax(nn.Module):
image_token_count: int image_token_count: int
text_token_count: int
image_vocab_count: int image_vocab_count: int
attention_head_count: int attention_head_count: int
embed_count: int embed_count: int

@@ -109,7 +109,7 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module): class DalleBartDecoderTorch(nn.Module):
def __init__( def __init__(
self, self,
image_vocab_size: int, image_vocab_count: int,
image_token_count: int, image_token_count: int,
sample_token_count: int, sample_token_count: int,
embed_count: int, embed_count: int,
@@ -117,16 +117,14 @@ class DalleBartDecoderTorch(nn.Module):
glu_embed_count: int, glu_embed_count: int,
layer_count: int, layer_count: int,
batch_count: int, batch_count: int,
start_token: int, start_token: int
is_verbose: bool
): ):
super().__init__() super().__init__()
self.is_verbose = is_verbose
self.layer_count = layer_count self.layer_count = layer_count
self.sample_token_count = sample_token_count self.sample_token_count = sample_token_count
self.condition_factor = 10.0 self.condition_factor = 10.0
self.image_token_count = image_token_count self.image_token_count = image_token_count
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count) self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
self.embed_positions = nn.Embedding(image_token_count, embed_count) self.embed_positions = nn.Embedding(image_token_count, embed_count)
self.layers: List[DecoderLayerTorch] = nn.ModuleList([ self.layers: List[DecoderLayerTorch] = nn.ModuleList([
DecoderLayerTorch( DecoderLayerTorch(
@@ -139,7 +137,7 @@ class DalleBartDecoderTorch(nn.Module):
]) ])
self.layernorm_embedding = nn.LayerNorm(embed_count) self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count) self.final_ln = nn.LayerNorm(embed_count)
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False) self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
self.attention_state_shape = ( self.attention_state_shape = (
layer_count, layer_count,
2 * batch_count, 2 * batch_count,

2
setup_torch.sh vendored

@@ -7,14 +7,12 @@ pip install torch
mkdir -p ./pretrained/dalle_bart_mega/ mkdir -p ./pretrained/dalle_bart_mega/
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/config.json -L --output ./pretrained/dalle_bart_mega/config.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
mkdir -p ./pretrained/dalle_bart_mini/ mkdir -p ./pretrained/dalle_bart_mini/
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/config_mini.json -L --output ./pretrained/dalle_bart_mini/config.json
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt

Loading…
Cancel
Save