From cf9656baa2e46ebb2d2a561769b2cce92aa546b7 Mon Sep 17 00:00:00 2001
From: Brett Kuprel
Date: Fri, 1 Jul 2022 20:17:20 -0400
Subject: [PATCH] added is_verbose flag

---
 image_from_text.py          |  3 ++-
 min_dalle/min_dalle.py      | 34 ++++++++++++++++++----------------
 min_dalle/text_tokenizer.py |  5 +++--
 3 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/image_from_text.py b/image_from_text.py
index e35341c..f85ac1c 100644
--- a/image_from_text.py
+++ b/image_from_text.py
@@ -44,7 +44,8 @@ def generate_image(
         is_mega=is_mega,
         models_root='pretrained',
         is_reusable=False,
-        sample_token_count=token_count
+        sample_token_count=token_count,
+        is_verbose=True
     )
 
     if token_count < 256:
diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py
index 2e92cc6..37d0c6b 100644
--- a/min_dalle/min_dalle.py
+++ b/min_dalle/min_dalle.py
@@ -21,11 +21,12 @@ class MinDalle:
         is_mega: bool,
         is_reusable: bool = True,
         models_root: str = 'pretrained',
-        sample_token_count: int = 256
+        sample_token_count: int = 256,
+        is_verbose: bool = True
     ):
-        print("initializing MinDalleTorch")
         self.is_mega = is_mega
         self.is_reusable = is_reusable
+        self.is_verbose = is_verbose
         self.sample_token_count = sample_token_count
         self.batch_count = 2
         self.text_token_count = 64
@@ -37,6 +38,7 @@ class MinDalle:
         self.text_vocab_count = 50272 if is_mega else 50264
         self.image_vocab_count = 16415 if is_mega else 16384
 
+        if self.is_verbose: print("initializing MinDalle")
         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
         dalle_path = os.path.join(models_root, model_name)
         vqgan_path = os.path.join(models_root, 'vqgan')
@@ -56,7 +58,7 @@ class MinDalle:
 
 
     def download_tokenizer(self):
-        print("downloading tokenizer params")
+        if self.is_verbose: print("downloading tokenizer params")
         suffix = '' if self.is_mega else '_mini'
         vocab = requests.get(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix))
         merges = requests.get(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix))
@@ -65,21 +67,21 @@ class MinDalle:
 
 
     def download_encoder(self):
-        print("downloading encoder params")
+        if self.is_verbose: print("downloading encoder params")
         suffix = '' if self.is_mega else '_mini'
         params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
         with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
 
 
     def download_decoder(self):
-        print("downloading decoder params")
+        if self.is_verbose: print("downloading decoder params")
         suffix = '' if self.is_mega else '_mini'
         params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
         with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
 
 
     def download_detokenizer(self):
-        print("downloading detokenizer params")
+        if self.is_verbose: print("downloading detokenizer params")
         params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
         with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
 
@@ -88,18 +90,18 @@ class MinDalle:
         is_downloaded = os.path.exists(self.vocab_path)
         is_downloaded &= os.path.exists(self.merges_path)
         if not is_downloaded: self.download_tokenizer()
-        print("intializing TextTokenizer")
+        if self.is_verbose: print("initializing TextTokenizer")
         with open(self.vocab_path, 'r', encoding='utf8') as f:
             vocab = json.load(f)
         with open(self.merges_path, 'r', encoding='utf8') as f:
             merges = f.read().split("\n")[1:-1]
-        self.tokenizer = TextTokenizer(vocab, merges)
+        self.tokenizer = TextTokenizer(vocab, merges, is_verbose=self.is_verbose)
 
 
     def init_encoder(self):
         is_downloaded = os.path.exists(self.encoder_params_path)
         if not is_downloaded: self.download_encoder()
-        print("initializing DalleBartEncoderTorch")
+        if self.is_verbose: print("initializing DalleBartEncoder")
         self.encoder = DalleBartEncoder(
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
@@ -117,7 +119,7 @@ class MinDalle:
     def init_decoder(self):
         is_downloaded = os.path.exists(self.decoder_params_path)
         if not is_downloaded: self.download_decoder()
-        print("initializing DalleBartDecoderTorch")
+        if self.is_verbose: print("initializing DalleBartDecoder")
         self.decoder = DalleBartDecoder(
             sample_token_count = self.sample_token_count,
             image_token_count = self.image_token_count,
@@ -138,7 +140,7 @@ class MinDalle:
     def init_detokenizer(self):
         is_downloaded = os.path.exists(self.detoker_params_path)
         if not is_downloaded: self.download_detokenizer()
-        print("initializing VQGanDetokenizer")
+        if self.is_verbose: print("initializing VQGanDetokenizer")
         self.detokenizer = VQGanDetokenizer()
         params = torch.load(self.detoker_params_path)
         self.detokenizer.load_state_dict(params)
@@ -147,9 +149,9 @@ class MinDalle:
 
 
     def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
-        print("tokenizing text")
+        if self.is_verbose: print("tokenizing text")
         tokens = self.tokenizer.tokenize(text)
-        print("text tokens", tokens)
+        if self.is_verbose: print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
         text_tokens[1, :len(tokens)] = tokens
@@ -158,12 +160,12 @@ class MinDalle:
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
 
         if not self.is_reusable: self.init_encoder()
-        print("encoding text tokens")
+        if self.is_verbose: print("encoding text tokens")
         encoder_state = self.encoder.forward(text_tokens)
         if not self.is_reusable: del self.encoder
 
         if not self.is_reusable: self.init_decoder()
-        print("sampling image tokens")
+        if self.is_verbose: print("sampling image tokens")
         if seed < 0: seed = random.randint(0, 2 ** 31)
         torch.manual_seed(seed)
         image_tokens = self.decoder.forward(text_tokens, encoder_state)
@@ -174,7 +176,7 @@ class MinDalle:
     def generate_image(self, text: str, seed: int) -> Image.Image:
         image_tokens = self.generate_image_tokens(text, seed)
         if not self.is_reusable: self.init_detokenizer()
-        print("detokenizing image")
+        if self.is_verbose: print("detokenizing image")
         image = self.detokenizer.forward(image_tokens).to(torch.uint8)
         if not self.is_reusable: del self.detokenizer
         image = Image.fromarray(image.to('cpu').detach().numpy())
diff --git a/min_dalle/text_tokenizer.py b/min_dalle/text_tokenizer.py
index 1d06349..01d2111 100644
--- a/min_dalle/text_tokenizer.py
+++ b/min_dalle/text_tokenizer.py
@@ -3,7 +3,8 @@ from typing import List, Tuple
 
 
 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: List[str]):
+    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
+        self.is_verbose = is_verbose
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
@@ -36,5 +37,5 @@ class TextTokenizer:
             (subwords[i + 2:] if i + 2 < len(subwords) else [])
         )
 
-        print(subwords)
+        if self.is_verbose: print(subwords)
         return subwords
\ No newline at end of file
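
Usage note: a minimal sketch of how the new flag would be exercised, assuming the min_dalle package from this repo is importable and the pretrained weights can be downloaded on first use (the prompt, seed, and output path below are illustrative, not from the patch):

    from min_dalle import MinDalle

    # is_verbose=False silences every print this patch gates:
    # initialization, parameter downloads, tokenization, and sampling.
    # The default of True preserves the previous logging behavior.
    model = MinDalle(is_mega=False, is_reusable=True, is_verbose=False)

    image = model.generate_image('a comfy chair that looks like an avocado', seed=7)
    image.save('generated.png')

Because init_tokenizer forwards is_verbose to TextTokenizer, a single constructor argument now controls logging across the whole pipeline.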