From 377d15cb16727ed45b4b23fdad6664f87857e4e5 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Mon, 4 Jul 2022 08:05:55 -0400 Subject: [PATCH] faster decoder self attention --- min_dalle.ipynb | 2 +- min_dalle/min_dalle.py | 4 +++- min_dalle/models/dalle_bart_decoder.py | 28 +++++++++++--------------- replicate/predict.py | 10 ++++----- 4 files changed, 21 insertions(+), 23 deletions(-) diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 6fc32ae..0335319 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -178,8 +178,8 @@ "%%time\n", "\n", "text = \"Dali painting of WALL·E\" #@param {type:\"string\"}\n", - "seed = 0 #@param {type:\"integer\"}\n", "grid_size = 2 #@param {type:\"integer\"}\n", + "seed = -1 #@param {type:\"integer\"}\n", "\n", "display(model.generate_image(text, seed, grid_size))" ] diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 35509db..c637a43 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -165,6 +165,7 @@ class MinDalle: if self.is_verbose: print("encoding text tokens") encoder_state = self.encoder.forward(text_tokens) if not self.is_reusable: del self.encoder + if torch.cuda.is_available(): torch.cuda.empty_cache() if not self.is_reusable: self.init_decoder() if self.is_verbose: print("sampling image tokens") @@ -175,7 +176,6 @@ class MinDalle: encoder_state ) if not self.is_reusable: del self.decoder - if torch.cuda.is_available(): torch.cuda.empty_cache() return image_tokens @@ -187,6 +187,7 @@ class MinDalle: ) -> Image.Image: image_count = grid_size ** 2 image_tokens = self.generate_image_tokens(text, seed, image_count) + if torch.cuda.is_available(): torch.cuda.empty_cache() if not self.is_reusable: self.init_detokenizer() if self.is_verbose: print("detokenizing image") images = self.detokenizer.forward(image_tokens).to(torch.uint8) @@ -194,4 +195,5 @@ class MinDalle: images = images.reshape([grid_size] * 2 + list(images.shape[1:])) image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2) image = Image.fromarray(image.to('cpu').detach().numpy()) + if torch.cuda.is_available(): torch.cuda.empty_cache() return image \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index ce18c93..50afe6d 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -20,25 +20,28 @@ class DecoderCrossAttention(AttentionBase): class DecoderSelfAttention(AttentionBase): + def __init__(self, head_count: int, embed_count: int): + super().__init__(head_count, embed_count) + token_indices = torch.arange(256) + if torch.cuda.is_available(): token_indices = token_indices.cuda() + self.token_indices = token_indices + def forward( self, decoder_state: FloatTensor, attention_state: FloatTensor, - attention_mask: BoolTensor, - token_mask: BoolTensor + token_index: LongTensor ) -> Tuple[FloatTensor, FloatTensor]: keys = self.k_proj.forward(decoder_state) values = self.v_proj.forward(decoder_state) queries = self.q_proj.forward(decoder_state) - attention_state = torch.where( - token_mask[None, :, None], - torch.cat([keys, values]), - attention_state - ) + attn_mask = self.token_indices < token_index + 1 + attn_mask = attn_mask[None][[0] * decoder_state.shape[0]] + attention_state[:, token_index] = torch.cat([keys, values]) batch_count = decoder_state.shape[0] keys = attention_state[:batch_count] values = attention_state[batch_count:] - decoder_state = super().forward(keys, values, queries, attention_mask) + decoder_state = super().forward(keys, values, queries, attn_mask) return decoder_state, attention_state @@ -60,9 +63,6 @@ class DecoderLayer(nn.Module): self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) self.glu = GLU(embed_count, glu_embed_count) - self.token_indices = torch.arange(self.image_token_count) - if torch.cuda.is_available(): - self.token_indices = self.token_indices.cuda() def forward( self, @@ -75,14 +75,10 @@ class DecoderLayer(nn.Module): # Self Attention residual = decoder_state decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) - self_attn_mask = self.token_indices < token_index + 1 - self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]] - token_mask = self.token_indices == token_index decoder_state, attention_state = self.self_attn.forward( decoder_state, attention_state, - self_attn_mask, - token_mask + token_index ) decoder_state = self.self_attn_layer_norm.forward(decoder_state) decoder_state = residual + decoder_state diff --git a/replicate/predict.py b/replicate/predict.py index fbf7c98..69c51ea 100644 --- a/replicate/predict.py +++ b/replicate/predict.py @@ -13,16 +13,16 @@ class Predictor(BasePredictor): description='Text', default='Dali painting of WALL·E' ), - seed: int = Input( - description='Set the seed to a positive number for reproducible results', - default=-1 - ), grid_size: int = Input( description='Size of the image grid', ge=1, le=4, default=4 - ) + ), + seed: int = Input( + description='Set the seed to a positive number for reproducible results', + default=-1 + ), ) -> Path: image = self.model.generate_image(text, seed, grid_size=grid_size) out_path = Path(tempfile.mkdtemp()) / 'output.jpg'