remove config.json dependency, default to torch in image_from_text.py

This commit is contained in:
Brett Kuprel
2022-07-01 12:03:37 -04:00
parent 4404e70764
commit 85f5866eff
10 changed files with 52 additions and 64 deletions

View File

@@ -11,12 +11,9 @@ class MinDalleBase:
self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(self.model_path))
config_path = os.path.join(self.model_path, 'config.json')
vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
with open(vocab_path, 'r', encoding='utf8') as f:
vocab = json.load(f)
with open(merges_path, 'r', encoding='utf8') as f:
@@ -29,8 +26,7 @@ class MinDalleBase:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
return text_tokens

View File

@@ -26,26 +26,25 @@ class MinDalleFlax(MinDalleBase):
def init_encoder(self):
print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = self.config['max_text_length'],
text_vocab_count = self.config['encoder_vocab_size'],
layer_count = self.config['encoder_layers']
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = 64,
text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = 24 if self.is_mega else 12
).bind({'params': self.model_params.pop('encoder')})
def init_decoder(self):
print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
image_token_count = 256,
image_vocab_count = 16415 if self.is_mega else 16384,
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
layer_count = 24 if self.is_mega else 12,
start_token = 16415 if self.is_mega else 16384
)

View File

@@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
def init_encoder(self):
print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'],
embed_count = self.config['d_model'],
attention_head_count = self.config['encoder_attention_heads'],
text_vocab_count = self.config['encoder_vocab_size'],
text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim']
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = 64,
text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = 24 if self.is_mega else 12
)
params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False)
@@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
def init_decoder(self):
print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = self.token_count,
embed_count = self.config['d_model'],
attention_head_count = self.config['decoder_attention_heads'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
batch_count = 2,
start_token = self.config['decoder_start_token_id'],
is_verbose = True
image_token_count = 256,
image_vocab_count = 16415 if self.is_mega else 16384,
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
layer_count = 24 if self.is_mega else 12,
start_token = 16415 if self.is_mega else 16384,
batch_count = 2
)
params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False)

View File

@@ -130,7 +130,6 @@ def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
class DalleBartDecoderFlax(nn.Module):
image_token_count: int
text_token_count: int
image_vocab_count: int
attention_head_count: int
embed_count: int

View File

@@ -109,7 +109,7 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module):
def __init__(
self,
image_vocab_size: int,
image_vocab_count: int,
image_token_count: int,
sample_token_count: int,
embed_count: int,
@@ -117,16 +117,14 @@ class DalleBartDecoderTorch(nn.Module):
glu_embed_count: int,
layer_count: int,
batch_count: int,
start_token: int,
is_verbose: bool
start_token: int
):
super().__init__()
self.is_verbose = is_verbose
self.layer_count = layer_count
self.sample_token_count = sample_token_count
self.condition_factor = 10.0
self.image_token_count = image_token_count
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
self.embed_positions = nn.Embedding(image_token_count, embed_count)
self.layers: List[DecoderLayerTorch] = nn.ModuleList([
DecoderLayerTorch(
@@ -139,7 +137,7 @@ class DalleBartDecoderTorch(nn.Module):
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
self.attention_state_shape = (
layer_count,
2 * batch_count,