simplified MinDalleTorch
parent 56f0563ad1
commit 18c72ed34d
@@ -19,45 +19,53 @@ class MinDalleTorch:
         self,
         is_mega: bool,
         is_reusable: bool = True,
-        token_count: int = 256
+        sample_token_count: int = 256
     ):
         print("initializing MinDalleTorch")
-        self.is_mega = is_mega
+        self.is_reusable = is_reusable
+        self.sample_token_count = sample_token_count
+        self.batch_count = 2
+        self.text_token_count = 64
+        self.image_token_count = 256
+        self.layer_count = 24 if is_mega else 12
+        self.attention_head_count = 32 if is_mega else 16
+        self.embed_count = 2048 if is_mega else 1024
+        self.glu_embed_count = 4096 if is_mega else 2730
+        self.text_vocab_count = 50272 if is_mega else 50264
+        self.image_vocab_count = 16415 if is_mega else 16384
+
         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
         self.model_path = os.path.join('pretrained', model_name)
-
-        print("reading files from {}".format(self.model_path))
-        vocab_path = os.path.join(self.model_path, 'vocab.json')
-        merges_path = os.path.join(self.model_path, 'merges.txt')
-
-        with open(vocab_path, 'r', encoding='utf8') as f:
-            vocab = json.load(f)
-        with open(merges_path, 'r', encoding='utf8') as f:
-            merges = f.read().split("\n")[1:-1]
-
-        self.tokenizer = TextTokenizer(vocab, merges)
-        self.is_reusable = is_reusable
-        self.token_count = token_count
-
         self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
         self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
         self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
 
+        self.init_tokenizer()
         if is_reusable:
             self.init_encoder()
             self.init_decoder()
             self.init_detokenizer()
 
+    def init_tokenizer(self):
+        print("reading files from {}".format(self.model_path))
+        vocab_path = os.path.join(self.model_path, 'vocab.json')
+        merges_path = os.path.join(self.model_path, 'merges.txt')
+        with open(vocab_path, 'r', encoding='utf8') as f:
+            vocab = json.load(f)
+        with open(merges_path, 'r', encoding='utf8') as f:
+            merges = f.read().split("\n")[1:-1]
+        self.tokenizer = TextTokenizer(vocab, merges)
+
 
     def init_encoder(self):
         print("initializing DalleBartEncoderTorch")
         self.encoder = DalleBartEncoderTorch(
-            attention_head_count = 32 if self.is_mega else 16,
-            embed_count = 2048 if self.is_mega else 1024,
-            glu_embed_count = 4096 if self.is_mega else 2730,
-            text_token_count = 64,
-            text_vocab_count = 50272 if self.is_mega else 50264,
-            layer_count = 24 if self.is_mega else 12
+            attention_head_count = self.attention_head_count,
+            embed_count = self.embed_count,
+            glu_embed_count = self.glu_embed_count,
+            text_token_count = self.text_token_count,
+            text_vocab_count = self.text_vocab_count,
+            layer_count = self.layer_count
         )
         params = torch.load(self.encoder_params_path)
         self.encoder.load_state_dict(params, strict=False)
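
A minimal construction sketch for the simplified class, assuming only the __init__ signature shown in the hunk above; the import path is an assumption about the package layout, not something this diff confirms:

    # hypothetical import path; adjust to the actual package layout
    from min_dalle.min_dalle_torch import MinDalleTorch

    model = MinDalleTorch(
        is_mega=False,            # mini config: 12 layers, 16 heads, 1024-dim embeddings
        is_reusable=True,         # __init__ then eagerly builds encoder, decoder, detokenizer
        sample_token_count=256    # renamed from token_count in this commit
    )
    # the tokenizer is now always initialized via self.init_tokenizer() in __init__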
@@ -68,15 +76,15 @@ class MinDalleTorch:
     def init_decoder(self):
         print("initializing DalleBartDecoderTorch")
         self.decoder = DalleBartDecoderTorch(
-            sample_token_count = self.token_count,
-            image_token_count = 256,
-            image_vocab_count = 16415 if self.is_mega else 16384,
-            attention_head_count = 32 if self.is_mega else 16,
-            embed_count = 2048 if self.is_mega else 1024,
-            glu_embed_count = 4096 if self.is_mega else 2730,
-            layer_count = 24 if self.is_mega else 12,
-            start_token = 16415 if self.is_mega else 16384,
-            batch_count = 2
+            sample_token_count = self.sample_token_count,
+            image_token_count = self.image_token_count,
+            image_vocab_count = self.image_vocab_count,
+            attention_head_count = self.attention_head_count,
+            embed_count = self.embed_count,
+            glu_embed_count = self.glu_embed_count,
+            layer_count = self.layer_count,
+            start_token = self.image_vocab_count,
+            batch_count = self.batch_count
         )
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
@@ -93,18 +101,14 @@ class MinDalleTorch:
         if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
 
 
-    def tokenize_text(self, text: str) -> numpy.ndarray:
+    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
         print("tokenizing text")
         tokens = self.tokenizer.tokenize(text)
         print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
         text_tokens[1, :len(tokens)] = tokens
-        return text_tokens
-
-
-    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
-        text_tokens = self.tokenize_text(text)
         text_tokens = torch.tensor(text_tokens).to(torch.long)
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
 
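
And a similarly hedged usage sketch for the renamed generation method; only the new signature, def generate_image_tokens(self, text: str, seed: int) -> LongTensor, is visible in the last hunk (the remainder of that hunk is truncated above), so nothing beyond the call itself is assumed here:

    # assumes the `model` instance from the construction sketch above
    image_tokens = model.generate_image_tokens('alien life', seed=7)
    print(image_tokens.shape)  # LongTensor of sampled image tokens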