|
|
|
@ -19,45 +19,53 @@ class MinDalleTorch: |
|
|
|
|
self, |
|
|
|
|
is_mega: bool, |
|
|
|
|
is_reusable: bool = True, |
|
|
|
|
token_count: int = 256 |
|
|
|
|
sample_token_count: int = 256 |
|
|
|
|
): |
|
|
|
|
print("initializing MinDalleTorch") |
|
|
|
|
self.is_mega = is_mega |
|
|
|
|
self.is_reusable = is_reusable |
|
|
|
|
self.sample_token_count = sample_token_count |
|
|
|
|
self.batch_count = 2 |
|
|
|
|
self.text_token_count = 64 |
|
|
|
|
self.image_token_count = 256 |
|
|
|
|
self.layer_count = 24 if is_mega else 12 |
|
|
|
|
self.attention_head_count = 32 if is_mega else 16 |
|
|
|
|
self.embed_count = 2048 if is_mega else 1024 |
|
|
|
|
self.glu_embed_count = 4096 if is_mega else 2730 |
|
|
|
|
self.text_vocab_count = 50272 if is_mega else 50264 |
|
|
|
|
self.image_vocab_count = 16415 if is_mega else 16384 |
|
|
|
|
|
|
|
|
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') |
|
|
|
|
self.model_path = os.path.join('pretrained', model_name) |
|
|
|
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') |
|
|
|
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') |
|
|
|
|
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') |
|
|
|
|
|
|
|
|
|
self.init_tokenizer() |
|
|
|
|
if is_reusable: |
|
|
|
|
self.init_encoder() |
|
|
|
|
self.init_decoder() |
|
|
|
|
self.init_detokenizer() |
|
|
|
|
|
|
|
|
|
def init_tokenizer(self): |
|
|
|
|
print("reading files from {}".format(self.model_path)) |
|
|
|
|
vocab_path = os.path.join(self.model_path, 'vocab.json') |
|
|
|
|
merges_path = os.path.join(self.model_path, 'merges.txt') |
|
|
|
|
|
|
|
|
|
with open(vocab_path, 'r', encoding='utf8') as f: |
|
|
|
|
vocab = json.load(f) |
|
|
|
|
with open(merges_path, 'r', encoding='utf8') as f: |
|
|
|
|
merges = f.read().split("\n")[1:-1] |
|
|
|
|
|
|
|
|
|
self.tokenizer = TextTokenizer(vocab, merges) |
|
|
|
|
self.is_reusable = is_reusable |
|
|
|
|
self.token_count = token_count |
|
|
|
|
|
|
|
|
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') |
|
|
|
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') |
|
|
|
|
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') |
|
|
|
|
|
|
|
|
|
if is_reusable: |
|
|
|
|
self.init_encoder() |
|
|
|
|
self.init_decoder() |
|
|
|
|
self.init_detokenizer() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_encoder(self): |
|
|
|
|
print("initializing DalleBartEncoderTorch") |
|
|
|
|
self.encoder = DalleBartEncoderTorch( |
|
|
|
|
attention_head_count = 32 if self.is_mega else 16, |
|
|
|
|
embed_count = 2048 if self.is_mega else 1024, |
|
|
|
|
glu_embed_count = 4096 if self.is_mega else 2730, |
|
|
|
|
text_token_count = 64, |
|
|
|
|
text_vocab_count = 50272 if self.is_mega else 50264, |
|
|
|
|
layer_count = 24 if self.is_mega else 12 |
|
|
|
|
attention_head_count = self.attention_head_count, |
|
|
|
|
embed_count = self.embed_count, |
|
|
|
|
glu_embed_count = self.glu_embed_count, |
|
|
|
|
text_token_count = self.text_token_count, |
|
|
|
|
text_vocab_count = self.text_vocab_count, |
|
|
|
|
layer_count = self.layer_count |
|
|
|
|
) |
|
|
|
|
params = torch.load(self.encoder_params_path) |
|
|
|
|
self.encoder.load_state_dict(params, strict=False) |
|
|
|
@ -68,15 +76,15 @@ class MinDalleTorch: |
|
|
|
|
def init_decoder(self): |
|
|
|
|
print("initializing DalleBartDecoderTorch") |
|
|
|
|
self.decoder = DalleBartDecoderTorch( |
|
|
|
|
sample_token_count = self.token_count, |
|
|
|
|
image_token_count = 256, |
|
|
|
|
image_vocab_count = 16415 if self.is_mega else 16384, |
|
|
|
|
attention_head_count = 32 if self.is_mega else 16, |
|
|
|
|
embed_count = 2048 if self.is_mega else 1024, |
|
|
|
|
glu_embed_count = 4096 if self.is_mega else 2730, |
|
|
|
|
layer_count = 24 if self.is_mega else 12, |
|
|
|
|
start_token = 16415 if self.is_mega else 16384, |
|
|
|
|
batch_count = 2 |
|
|
|
|
sample_token_count = self.sample_token_count, |
|
|
|
|
image_token_count = self.image_token_count, |
|
|
|
|
image_vocab_count = self.image_vocab_count, |
|
|
|
|
attention_head_count = self.attention_head_count, |
|
|
|
|
embed_count = self.embed_count, |
|
|
|
|
glu_embed_count = self.glu_embed_count, |
|
|
|
|
layer_count = self.layer_count, |
|
|
|
|
start_token = self.image_vocab_count, |
|
|
|
|
batch_count = self.batch_count |
|
|
|
|
) |
|
|
|
|
params = torch.load(self.decoder_params_path) |
|
|
|
|
self.decoder.load_state_dict(params, strict=False) |
|
|
|
@ -93,18 +101,14 @@ class MinDalleTorch: |
|
|
|
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize_text(self, text: str) -> numpy.ndarray: |
|
|
|
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor: |
|
|
|
|
print("tokenizing text") |
|
|
|
|
tokens = self.tokenizer.tokenize(text) |
|
|
|
|
print("text tokens", tokens) |
|
|
|
|
text_tokens = numpy.ones((2, 64), dtype=numpy.int32) |
|
|
|
|
text_tokens[0, :2] = [tokens[0], tokens[-1]] |
|
|
|
|
text_tokens[1, :len(tokens)] = tokens |
|
|
|
|
return text_tokens |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor: |
|
|
|
|
text_tokens = self.tokenize_text(text) |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda() |
|
|
|
|
|
|
|
|
|