simplified MinDalleTorch
parent 56f0563ad1
commit 18c72ed34d
@@ -19,45 +19,53 @@ class MinDalleTorch:
         self,
         is_mega: bool,
         is_reusable: bool = True,
-        token_count: int = 256
+        sample_token_count: int = 256
     ):
         print("initializing MinDalleTorch")
         self.is_mega = is_mega
+        self.is_reusable = is_reusable
+        self.sample_token_count = sample_token_count
+        self.batch_count = 2
+        self.text_token_count = 64
+        self.image_token_count = 256
+        self.layer_count = 24 if is_mega else 12
+        self.attention_head_count = 32 if is_mega else 16
+        self.embed_count = 2048 if is_mega else 1024
+        self.glu_embed_count = 4096 if is_mega else 2730
+        self.text_vocab_count = 50272 if is_mega else 50264
+        self.image_vocab_count = 16415 if is_mega else 16384

         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
         self.model_path = os.path.join('pretrained', model_name)

-        print("reading files from {}".format(self.model_path))
-        vocab_path = os.path.join(self.model_path, 'vocab.json')
-        merges_path = os.path.join(self.model_path, 'merges.txt')

-        with open(vocab_path, 'r', encoding='utf8') as f:
-            vocab = json.load(f)
-        with open(merges_path, 'r', encoding='utf8') as f:
-            merges = f.read().split("\n")[1:-1]

-        self.tokenizer = TextTokenizer(vocab, merges)
-        self.is_reusable = is_reusable
-        self.token_count = token_count

         self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
         self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
         self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')

+        self.init_tokenizer()
         if is_reusable:
             self.init_encoder()
             self.init_decoder()
             self.init_detokenizer()

+    def init_tokenizer(self):
+        print("reading files from {}".format(self.model_path))
+        vocab_path = os.path.join(self.model_path, 'vocab.json')
+        merges_path = os.path.join(self.model_path, 'merges.txt')
+        with open(vocab_path, 'r', encoding='utf8') as f:
+            vocab = json.load(f)
+        with open(merges_path, 'r', encoding='utf8') as f:
+            merges = f.read().split("\n")[1:-1]
+        self.tokenizer = TextTokenizer(vocab, merges)


     def init_encoder(self):
         print("initializing DalleBartEncoderTorch")
         self.encoder = DalleBartEncoderTorch(
-            attention_head_count = 32 if self.is_mega else 16,
-            embed_count = 2048 if self.is_mega else 1024,
-            glu_embed_count = 4096 if self.is_mega else 2730,
-            text_token_count = 64,
-            text_vocab_count = 50272 if self.is_mega else 50264,
-            layer_count = 24 if self.is_mega else 12
+            attention_head_count = self.attention_head_count,
+            embed_count = self.embed_count,
+            glu_embed_count = self.glu_embed_count,
+            text_token_count = self.text_token_count,
+            text_vocab_count = self.text_vocab_count,
+            layer_count = self.layer_count
         )
         params = torch.load(self.encoder_params_path)
         self.encoder.load_state_dict(params, strict=False)
@@ -68,15 +76,15 @@ class MinDalleTorch:
     def init_decoder(self):
         print("initializing DalleBartDecoderTorch")
         self.decoder = DalleBartDecoderTorch(
-            sample_token_count = self.token_count,
-            image_token_count = 256,
-            image_vocab_count = 16415 if self.is_mega else 16384,
-            attention_head_count = 32 if self.is_mega else 16,
-            embed_count = 2048 if self.is_mega else 1024,
-            glu_embed_count = 4096 if self.is_mega else 2730,
-            layer_count = 24 if self.is_mega else 12,
-            start_token = 16415 if self.is_mega else 16384,
-            batch_count = 2
+            sample_token_count = self.sample_token_count,
+            image_token_count = self.image_token_count,
+            image_vocab_count = self.image_vocab_count,
+            attention_head_count = self.attention_head_count,
+            embed_count = self.embed_count,
+            glu_embed_count = self.glu_embed_count,
+            layer_count = self.layer_count,
+            start_token = self.image_vocab_count,
+            batch_count = self.batch_count
         )
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
@@ -93,18 +101,14 @@ class MinDalleTorch:
         if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()


-    def tokenize_text(self, text: str) -> numpy.ndarray:
+    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
         print("tokenizing text")
         tokens = self.tokenizer.tokenize(text)
         print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
         text_tokens[1, :len(tokens)] = tokens
-        return text_tokens


-    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
-        text_tokens = self.tokenize_text(text)
         text_tokens = torch.tensor(text_tokens).to(torch.long)
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
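For context, here is a minimal usage sketch of the refactored interface. Only the constructor arguments and the generate_image_tokens signature are taken from the diff above; the import path and the example prompt are assumptions.

    # usage sketch -- the import path is an assumption, not part of this commit
    from min_dalle.min_dalle_torch import MinDalleTorch

    # the constructor now takes sample_token_count (renamed from token_count)
    # and derives all other hyperparameters from is_mega
    model = MinDalleTorch(is_mega=False, is_reusable=True, sample_token_count=256)

    # generate_image_tokens now tokenizes the text itself and,
    # per its annotation, returns a LongTensor of image tokens
    image_tokens = model.generate_image_tokens("a comfy chair that looks like an avocado", seed=7)
    print(image_tokens.shape)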