From deefd24919f5f5b5b96127b749675f70cedb9435 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Mon, 4 Jul 2022 09:58:00 -0400 Subject: [PATCH] decode_row --- image_from_text.py | 17 ++++++--- min_dalle/min_dalle.py | 10 ++--- min_dalle/models/dalle_bart_decoder.py | 53 +++++++++++++++++--------- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/image_from_text.py b/image_from_text.py index 061fd05..77429f8 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -14,7 +14,7 @@ parser.add_argument('--seed', type=int, default=-1) parser.add_argument('--grid-size', type=int, default=1) parser.add_argument('--image-path', type=str, default='generated') parser.add_argument('--models-root', type=str, default='pretrained') -parser.add_argument('--token-count', type=int, default=256) # for debugging +parser.add_argument('--row-count', type=int, default=16) # for debugging def ascii_from_image(image: Image.Image, size: int) -> str: @@ -42,18 +42,23 @@ def generate_image( grid_size: int, image_path: str, models_root: str, - token_count: int + row_count: int ): model = MinDalle( is_mega=is_mega, models_root=models_root, is_reusable=False, - sample_token_count=token_count, is_verbose=True ) - if token_count < 256: - image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2) + if row_count < 16: + token_count = 16 * row_count + image_tokens = model.generate_image_tokens( + text, + seed, + grid_size ** 2, + row_count + ) image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy() print('image tokens', image_tokens) else: @@ -72,5 +77,5 @@ if __name__ == '__main__': grid_size=args.grid_size, image_path=args.image_path, models_root=args.models_root, - token_count=args.token_count + row_count=args.row_count ) \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index c637a43..f156f5c 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -20,13 +20,11 @@ class MinDalle: is_mega: bool, is_reusable: bool = True, models_root: str = 'pretrained', - sample_token_count: int = 256, is_verbose = True ): self.is_mega = is_mega self.is_reusable = is_reusable self.is_verbose = is_verbose - self.sample_token_count = sample_token_count self.text_token_count = 64 self.image_token_count = 256 self.layer_count = 24 if is_mega else 12 @@ -119,7 +117,6 @@ class MinDalle: if not is_downloaded: self.download_decoder() if self.is_verbose: print("initializing DalleBartDecoder") self.decoder = DalleBartDecoder( - sample_token_count = self.sample_token_count, image_token_count = self.image_token_count, image_vocab_count = self.image_vocab_count, attention_head_count = self.attention_head_count, @@ -149,7 +146,8 @@ class MinDalle: self, text: str, seed: int, - image_count: int + image_count: int, + row_count: int ) -> LongTensor: if self.is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text) @@ -172,6 +170,7 @@ class MinDalle: if seed > 0: torch.manual_seed(seed) image_tokens = self.decoder.forward( image_count, + row_count, text_tokens, encoder_state ) @@ -186,7 +185,8 @@ class MinDalle: grid_size: int = 1 ) -> Image.Image: image_count = grid_size ** 2 - image_tokens = self.generate_image_tokens(text, seed, image_count) + row_count = 16 + image_tokens = self.generate_image_tokens(text, seed, image_count, row_count) if torch.cuda.is_available(): torch.cuda.empty_cache() if not self.is_reusable: self.init_detokenizer() if self.is_verbose: print("detokenizing image") diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index b45e6cc..cf82704 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -111,7 +111,6 @@ class DalleBartDecoder(nn.Module): self, image_vocab_count: int, image_token_count: int, - sample_token_count: int, embed_count: int, attention_head_count: int, glu_embed_count: int, @@ -121,7 +120,6 @@ class DalleBartDecoder(nn.Module): super().__init__() self.layer_count = layer_count self.embed_count = embed_count - self.sample_token_count = sample_token_count self.condition_factor = 10.0 self.image_token_count = image_token_count self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) @@ -139,7 +137,7 @@ class DalleBartDecoder(nn.Module): self.final_ln = nn.LayerNorm(embed_count) self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False) self.zero_prob = torch.zeros([1]) - self.token_indices = torch.arange(self.sample_token_count) + self.token_indices = torch.arange(self.image_token_count) self.start_token = torch.tensor([start_token]).to(torch.long) if torch.cuda.is_available(): self.zero_prob = self.zero_prob.cuda() @@ -185,11 +183,35 @@ class DalleBartDecoder(nn.Module): torch.exp(logits - top_logits[:, [0]]) ) return probs, attention_state + + + def decode_row( + self, + row_index: int, + attention_mask: BoolTensor, + encoder_state: FloatTensor, + attention_state: FloatTensor, + image_tokens_sequence: LongTensor + ) -> Tuple[FloatTensor, LongTensor]: + for col_index in range(16): + i = 16 * row_index + col_index + probs, attention_state = self.decode_step( + attention_mask = attention_mask, + encoder_state = encoder_state, + attention_state = attention_state, + prev_tokens = image_tokens_sequence[:, i], + token_index = self.token_indices[[i]] + ) + + image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0] + + return attention_state, image_tokens_sequence def forward( self, image_count: int, + row_count: int, text_tokens: LongTensor, encoder_state: FloatTensor ) -> LongTensor: @@ -206,7 +228,7 @@ class DalleBartDecoder(nn.Module): ) attention_state = torch.zeros(attention_state_shape) image_tokens_sequence = torch.full( - (image_count, self.image_token_count), + (image_count, self.image_token_count + 1), 6965, # black token dtype=torch.long ) @@ -214,18 +236,15 @@ class DalleBartDecoder(nn.Module): attention_state = attention_state.cuda() image_tokens_sequence = image_tokens_sequence.cuda() - image_tokens = self.start_token[[0] * image_count] - - for i in range(self.sample_token_count): - probs, attention_state = self.decode_step( - attention_mask = attention_mask, - encoder_state = encoder_state, - attention_state = attention_state, - prev_tokens = image_tokens, - token_index = self.token_indices[[i]] - ) + image_tokens_sequence[:, 0] = self.start_token[0] - image_tokens = torch.multinomial(probs, 1)[:, 0] - image_tokens_sequence[:, i] = image_tokens + for row_index in range(row_count): + attention_state, image_tokens_sequence = self.decode_row( + row_index, + attention_mask, + encoder_state, + attention_state, + image_tokens_sequence + ) - return image_tokens_sequence \ No newline at end of file + return image_tokens_sequence[:, 1:] \ No newline at end of file