From 18c72ed34ddfa2b6bf0747cbf0bc9bdbd034803e Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Fri, 1 Jul 2022 15:53:39 -0400 Subject: [PATCH] simplified MinDalleTorch --- min_dalle/min_dalle_torch.py | 76 +++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index f6c901c..76375b6 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -19,45 +19,53 @@ class MinDalleTorch: self, is_mega: bool, is_reusable: bool = True, - token_count: int = 256 + sample_token_count: int = 256 ): print("initializing MinDalleTorch") - self.is_mega = is_mega + self.is_reusable = is_reusable + self.sample_token_count = sample_token_count + self.batch_count = 2 + self.text_token_count = 64 + self.image_token_count = 256 + self.layer_count = 24 if is_mega else 12 + self.attention_head_count = 32 if is_mega else 16 + self.embed_count = 2048 if is_mega else 1024 + self.glu_embed_count = 4096 if is_mega else 2730 + self.text_vocab_count = 50272 if is_mega else 50264 + self.image_vocab_count = 16415 if is_mega else 16384 + model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') self.model_path = os.path.join('pretrained', model_name) - - print("reading files from {}".format(self.model_path)) - vocab_path = os.path.join(self.model_path, 'vocab.json') - merges_path = os.path.join(self.model_path, 'merges.txt') - - with open(vocab_path, 'r', encoding='utf8') as f: - vocab = json.load(f) - with open(merges_path, 'r', encoding='utf8') as f: - merges = f.read().split("\n")[1:-1] - - self.tokenizer = TextTokenizer(vocab, merges) - self.is_reusable = is_reusable - self.token_count = token_count - self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') + self.init_tokenizer() if is_reusable: self.init_encoder() self.init_decoder() self.init_detokenizer() + def init_tokenizer(self): + print("reading files from {}".format(self.model_path)) + vocab_path = os.path.join(self.model_path, 'vocab.json') + merges_path = os.path.join(self.model_path, 'merges.txt') + with open(vocab_path, 'r', encoding='utf8') as f: + vocab = json.load(f) + with open(merges_path, 'r', encoding='utf8') as f: + merges = f.read().split("\n")[1:-1] + self.tokenizer = TextTokenizer(vocab, merges) + def init_encoder(self): print("initializing DalleBartEncoderTorch") self.encoder = DalleBartEncoderTorch( - attention_head_count = 32 if self.is_mega else 16, - embed_count = 2048 if self.is_mega else 1024, - glu_embed_count = 4096 if self.is_mega else 2730, - text_token_count = 64, - text_vocab_count = 50272 if self.is_mega else 50264, - layer_count = 24 if self.is_mega else 12 + attention_head_count = self.attention_head_count, + embed_count = self.embed_count, + glu_embed_count = self.glu_embed_count, + text_token_count = self.text_token_count, + text_vocab_count = self.text_vocab_count, + layer_count = self.layer_count ) params = torch.load(self.encoder_params_path) self.encoder.load_state_dict(params, strict=False) @@ -68,15 +76,15 @@ class MinDalleTorch: def init_decoder(self): print("initializing DalleBartDecoderTorch") self.decoder = DalleBartDecoderTorch( - sample_token_count = self.token_count, - image_token_count = 256, - image_vocab_count = 16415 if self.is_mega else 16384, - attention_head_count = 32 if self.is_mega else 16, - embed_count = 2048 if self.is_mega else 1024, - glu_embed_count = 4096 if self.is_mega else 2730, - layer_count = 24 if self.is_mega else 12, - start_token = 16415 if self.is_mega else 16384, - batch_count = 2 + sample_token_count = self.sample_token_count, + image_token_count = self.image_token_count, + image_vocab_count = self.image_vocab_count, + attention_head_count = self.attention_head_count, + embed_count = self.embed_count, + glu_embed_count = self.glu_embed_count, + layer_count = self.layer_count, + start_token = self.image_vocab_count, + batch_count = self.batch_count ) params = torch.load(self.decoder_params_path) self.decoder.load_state_dict(params, strict=False) @@ -93,18 +101,14 @@ class MinDalleTorch: if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() - def tokenize_text(self, text: str) -> numpy.ndarray: + def generate_image_tokens(self, text: str, seed: int) -> LongTensor: print("tokenizing text") tokens = self.tokenizer.tokenize(text) print("text tokens", tokens) text_tokens = numpy.ones((2, 64), dtype=numpy.int32) text_tokens[0, :2] = [tokens[0], tokens[-1]] text_tokens[1, :len(tokens)] = tokens - return text_tokens - - def generate_image_tokens(self, text: str, seed: int) -> LongTensor: - text_tokens = self.tokenize_text(text) text_tokens = torch.tensor(text_tokens).to(torch.long) if torch.cuda.is_available(): text_tokens = text_tokens.cuda()