faster decoder self attention

This commit is contained in:
Brett Kuprel
2022-07-04 08:05:55 -04:00
parent a79d30f718
commit 377d15cb16
4 changed files with 21 additions and 23 deletions

View File

@@ -165,6 +165,7 @@ class MinDalle:
if self.is_verbose: print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
if not self.is_reusable: del self.encoder
if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_decoder()
if self.is_verbose: print("sampling image tokens")
@@ -175,7 +176,6 @@ class MinDalle:
encoder_state
)
if not self.is_reusable: del self.decoder
if torch.cuda.is_available(): torch.cuda.empty_cache()
return image_tokens
@@ -187,6 +187,7 @@ class MinDalle:
) -> Image.Image:
image_count = grid_size ** 2
image_tokens = self.generate_image_tokens(text, seed, image_count)
if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer()
if self.is_verbose: print("detokenizing image")
images = self.detokenizer.forward(image_tokens).to(torch.uint8)
@@ -194,4 +195,5 @@ class MinDalle:
images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
image = Image.fromarray(image.to('cpu').detach().numpy())
if torch.cuda.is_available(): torch.cuda.empty_cache()
return image

View File

@@ -20,25 +20,28 @@ class DecoderCrossAttention(AttentionBase):
class DecoderSelfAttention(AttentionBase):
def __init__(self, head_count: int, embed_count: int):
super().__init__(head_count, embed_count)
token_indices = torch.arange(256)
if torch.cuda.is_available(): token_indices = token_indices.cuda()
self.token_indices = token_indices
def forward(
self,
decoder_state: FloatTensor,
attention_state: FloatTensor,
attention_mask: BoolTensor,
token_mask: BoolTensor
token_index: LongTensor
) -> Tuple[FloatTensor, FloatTensor]:
keys = self.k_proj.forward(decoder_state)
values = self.v_proj.forward(decoder_state)
queries = self.q_proj.forward(decoder_state)
attention_state = torch.where(
token_mask[None, :, None],
torch.cat([keys, values]),
attention_state
)
attn_mask = self.token_indices < token_index + 1
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
attention_state[:, token_index] = torch.cat([keys, values])
batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count]
values = attention_state[batch_count:]
decoder_state = super().forward(keys, values, queries, attention_mask)
decoder_state = super().forward(keys, values, queries, attn_mask)
return decoder_state, attention_state
@@ -60,9 +63,6 @@ class DecoderLayer(nn.Module):
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLU(embed_count, glu_embed_count)
self.token_indices = torch.arange(self.image_token_count)
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(
self,
@@ -75,14 +75,10 @@ class DecoderLayer(nn.Module):
# Self Attention
residual = decoder_state
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
self_attn_mask = self.token_indices < token_index + 1
self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
token_mask = self.token_indices == token_index
decoder_state, attention_state = self.self_attn.forward(
decoder_state,
attention_state,
self_attn_mask,
token_mask
token_index
)
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
decoder_state = residual + decoder_state