remove config.json dependency, default to torch in image_from_text.py
This commit is contained in:
parent
4404e70764
commit
85f5866eff
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -2,7 +2,7 @@
|
||||||
**/.cache/
|
**/.cache/
|
||||||
**/*.pkl
|
**/*.pkl
|
||||||
**/.DS*
|
**/.DS*
|
||||||
*.pt
|
**/*.pt
|
||||||
*.mlpackage
|
*.mlpackage
|
||||||
**/*.ckpt
|
**/*.ckpt
|
||||||
.vscode
|
.vscode
|
||||||
|
|
6
README.md
vendored
6
README.md
vendored
|
@ -22,19 +22,19 @@ Use the python script `image_from_text.py` to generate images from the command l
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='artificial intelligence' --torch
|
python image_from_text.py --text='artificial intelligence' --seed=7
|
||||||
```
|
```
|
||||||
![Alien](examples/artificial_intelligence.png)
|
![Alien](examples/artificial_intelligence.png)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --mega --seed=10
|
python image_from_text.py --text='a comfy chair that looks like an avocado' --mega --seed=10
|
||||||
```
|
```
|
||||||
![Avocado Armchair](examples/avocado_armchair.png)
|
![Avocado Armchair](examples/avocado_armchair.png)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
python image_from_text.py --text='court sketch of godzilla on trial' --torch --mega --seed=40
|
python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=40
|
||||||
```
|
```
|
||||||
|
|
||||||
![Godzilla Trial](examples/godzilla_trial.png)
|
![Godzilla Trial](examples/godzilla_trial.png)
|
||||||
|
|
|
@ -11,9 +11,9 @@ parser.add_argument('--no-mega', dest='mega', action='store_false')
|
||||||
parser.set_defaults(mega=False)
|
parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--torch', action='store_true')
|
parser.add_argument('--torch', action='store_true')
|
||||||
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
parser.add_argument('--no-torch', dest='torch', action='store_false')
|
||||||
parser.set_defaults(torch=False)
|
parser.set_defaults(torch=True)
|
||||||
parser.add_argument('--text', type=str, default='alien life')
|
parser.add_argument('--text', type=str, default='cat')
|
||||||
parser.add_argument('--seed', type=int, default=7)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ def generate_image(
|
||||||
if is_torch:
|
if is_torch:
|
||||||
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
|
image_generator = MinDalleTorch(is_mega, is_reusable, token_count)
|
||||||
|
|
||||||
if token_count < image_generator.config['image_length']:
|
if token_count < 256:
|
||||||
image_tokens = image_generator.generate_image_tokens(text, seed)
|
image_tokens = image_generator.generate_image_tokens(text, seed)
|
||||||
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||||
return
|
return
|
||||||
|
|
25
min_dalle.ipynb
vendored
25
min_dalle.ipynb
vendored
File diff suppressed because one or more lines are too long
|
@ -11,12 +11,9 @@ class MinDalleBase:
|
||||||
self.model_path = os.path.join('pretrained', model_name)
|
self.model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
print("reading files from {}".format(self.model_path))
|
print("reading files from {}".format(self.model_path))
|
||||||
config_path = os.path.join(self.model_path, 'config.json')
|
|
||||||
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
||||||
merges_path = os.path.join(self.model_path, 'merges.txt')
|
merges_path = os.path.join(self.model_path, 'merges.txt')
|
||||||
|
|
||||||
with open(config_path, 'r', encoding='utf8') as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
with open(vocab_path, 'r', encoding='utf8') as f:
|
with open(vocab_path, 'r', encoding='utf8') as f:
|
||||||
vocab = json.load(f)
|
vocab = json.load(f)
|
||||||
with open(merges_path, 'r', encoding='utf8') as f:
|
with open(merges_path, 'r', encoding='utf8') as f:
|
||||||
|
@ -29,8 +26,7 @@ class MinDalleBase:
|
||||||
print("tokenizing text")
|
print("tokenizing text")
|
||||||
tokens = self.tokenizer.tokenize(text)
|
tokens = self.tokenizer.tokenize(text)
|
||||||
print("text tokens", tokens)
|
print("text tokens", tokens)
|
||||||
text_token_count = self.config['max_text_length']
|
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
|
||||||
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
|
||||||
text_tokens[0, :2] = [tokens[0], tokens[-1]]
|
text_tokens[0, :2] = [tokens[0], tokens[-1]]
|
||||||
text_tokens[1, :len(tokens)] = tokens
|
text_tokens[1, :len(tokens)] = tokens
|
||||||
return text_tokens
|
return text_tokens
|
|
@ -26,26 +26,25 @@ class MinDalleFlax(MinDalleBase):
|
||||||
def init_encoder(self):
|
def init_encoder(self):
|
||||||
print("initializing DalleBartEncoderFlax")
|
print("initializing DalleBartEncoderFlax")
|
||||||
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
||||||
attention_head_count = self.config['encoder_attention_heads'],
|
attention_head_count = 32 if self.is_mega else 16,
|
||||||
embed_count = self.config['d_model'],
|
embed_count = 2048 if self.is_mega else 1024,
|
||||||
glu_embed_count = self.config['encoder_ffn_dim'],
|
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||||
text_token_count = self.config['max_text_length'],
|
text_token_count = 64,
|
||||||
text_vocab_count = self.config['encoder_vocab_size'],
|
text_vocab_count = 50272 if self.is_mega else 50264,
|
||||||
layer_count = self.config['encoder_layers']
|
layer_count = 24 if self.is_mega else 12
|
||||||
).bind({'params': self.model_params.pop('encoder')})
|
).bind({'params': self.model_params.pop('encoder')})
|
||||||
|
|
||||||
|
|
||||||
def init_decoder(self):
|
def init_decoder(self):
|
||||||
print("initializing DalleBartDecoderFlax")
|
print("initializing DalleBartDecoderFlax")
|
||||||
self.decoder = DalleBartDecoderFlax(
|
self.decoder = DalleBartDecoderFlax(
|
||||||
image_token_count = self.config['image_length'],
|
image_token_count = 256,
|
||||||
text_token_count = self.config['max_text_length'],
|
image_vocab_count = 16415 if self.is_mega else 16384,
|
||||||
image_vocab_count = self.config['image_vocab_size'],
|
attention_head_count = 32 if self.is_mega else 16,
|
||||||
attention_head_count = self.config['decoder_attention_heads'],
|
embed_count = 2048 if self.is_mega else 1024,
|
||||||
embed_count = self.config['d_model'],
|
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
layer_count = 24 if self.is_mega else 12,
|
||||||
layer_count = self.config['decoder_layers'],
|
start_token = 16415 if self.is_mega else 16384
|
||||||
start_token = self.config['decoder_start_token_id']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
|
||||||
def init_encoder(self):
|
def init_encoder(self):
|
||||||
print("initializing DalleBartEncoderTorch")
|
print("initializing DalleBartEncoderTorch")
|
||||||
self.encoder = DalleBartEncoderTorch(
|
self.encoder = DalleBartEncoderTorch(
|
||||||
layer_count = self.config['encoder_layers'],
|
attention_head_count = 32 if self.is_mega else 16,
|
||||||
embed_count = self.config['d_model'],
|
embed_count = 2048 if self.is_mega else 1024,
|
||||||
attention_head_count = self.config['encoder_attention_heads'],
|
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||||
text_vocab_count = self.config['encoder_vocab_size'],
|
text_token_count = 64,
|
||||||
text_token_count = self.config['max_text_length'],
|
text_vocab_count = 50272 if self.is_mega else 50264,
|
||||||
glu_embed_count = self.config['encoder_ffn_dim']
|
layer_count = 24 if self.is_mega else 12
|
||||||
)
|
)
|
||||||
params = torch.load(self.encoder_params_path)
|
params = torch.load(self.encoder_params_path)
|
||||||
self.encoder.load_state_dict(params, strict=False)
|
self.encoder.load_state_dict(params, strict=False)
|
||||||
|
@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
|
||||||
def init_decoder(self):
|
def init_decoder(self):
|
||||||
print("initializing DalleBartDecoderTorch")
|
print("initializing DalleBartDecoderTorch")
|
||||||
self.decoder = DalleBartDecoderTorch(
|
self.decoder = DalleBartDecoderTorch(
|
||||||
image_vocab_size = self.config['image_vocab_size'],
|
|
||||||
image_token_count = self.config['image_length'],
|
|
||||||
sample_token_count = self.token_count,
|
sample_token_count = self.token_count,
|
||||||
embed_count = self.config['d_model'],
|
image_token_count = 256,
|
||||||
attention_head_count = self.config['decoder_attention_heads'],
|
image_vocab_count = 16415 if self.is_mega else 16384,
|
||||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
attention_head_count = 32 if self.is_mega else 16,
|
||||||
layer_count = self.config['decoder_layers'],
|
embed_count = 2048 if self.is_mega else 1024,
|
||||||
batch_count = 2,
|
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||||
start_token = self.config['decoder_start_token_id'],
|
layer_count = 24 if self.is_mega else 12,
|
||||||
is_verbose = True
|
start_token = 16415 if self.is_mega else 16384,
|
||||||
|
batch_count = 2
|
||||||
)
|
)
|
||||||
params = torch.load(self.decoder_params_path)
|
params = torch.load(self.decoder_params_path)
|
||||||
self.decoder.load_state_dict(params, strict=False)
|
self.decoder.load_state_dict(params, strict=False)
|
||||||
|
|
|
@ -130,7 +130,6 @@ def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||||
|
|
||||||
class DalleBartDecoderFlax(nn.Module):
|
class DalleBartDecoderFlax(nn.Module):
|
||||||
image_token_count: int
|
image_token_count: int
|
||||||
text_token_count: int
|
|
||||||
image_vocab_count: int
|
image_vocab_count: int
|
||||||
attention_head_count: int
|
attention_head_count: int
|
||||||
embed_count: int
|
embed_count: int
|
||||||
|
|
|
@ -109,7 +109,7 @@ class DecoderLayerTorch(nn.Module):
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoderTorch(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_vocab_size: int,
|
image_vocab_count: int,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
sample_token_count: int,
|
sample_token_count: int,
|
||||||
embed_count: int,
|
embed_count: int,
|
||||||
|
@ -117,16 +117,14 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
glu_embed_count: int,
|
glu_embed_count: int,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
batch_count: int,
|
batch_count: int,
|
||||||
start_token: int,
|
start_token: int
|
||||||
is_verbose: bool
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.is_verbose = is_verbose
|
|
||||||
self.layer_count = layer_count
|
self.layer_count = layer_count
|
||||||
self.sample_token_count = sample_token_count
|
self.sample_token_count = sample_token_count
|
||||||
self.condition_factor = 10.0
|
self.condition_factor = 10.0
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
||||||
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
||||||
self.layers: List[DecoderLayerTorch] = nn.ModuleList([
|
self.layers: List[DecoderLayerTorch] = nn.ModuleList([
|
||||||
DecoderLayerTorch(
|
DecoderLayerTorch(
|
||||||
|
@ -139,7 +137,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
])
|
])
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||||
self.final_ln = nn.LayerNorm(embed_count)
|
self.final_ln = nn.LayerNorm(embed_count)
|
||||||
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
||||||
self.attention_state_shape = (
|
self.attention_state_shape = (
|
||||||
layer_count,
|
layer_count,
|
||||||
2 * batch_count,
|
2 * batch_count,
|
||||||
|
|
2
setup_torch.sh
vendored
2
setup_torch.sh
vendored
|
@ -7,14 +7,12 @@ pip install torch
|
||||||
mkdir -p ./pretrained/dalle_bart_mega/
|
mkdir -p ./pretrained/dalle_bart_mega/
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/config.json -L --output ./pretrained/dalle_bart_mega/config.json
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
|
||||||
|
|
||||||
mkdir -p ./pretrained/dalle_bart_mini/
|
mkdir -p ./pretrained/dalle_bart_mini/
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab_mini.json -L --output ./pretrained/dalle_bart_mini/vocab.json
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges_mini.txt -L --output ./pretrained/dalle_bart_mini/merges.txt
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/config_mini.json -L --output ./pretrained/dalle_bart_mini/config.json
|
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder_mini.pt -L --output ./pretrained/dalle_bart_mini/encoder.pt
|
||||||
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder_mini.pt -L --output ./pretrained/dalle_bart_mini/decoder.pt
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user