inplace attention state, faster and less memory

This commit is contained in:
Brett Kuprel
2022-07-04 09:14:37 -04:00
parent aca617dc64
commit 6f617fe98f
4 changed files with 25 additions and 15 deletions

View File

@@ -20,9 +20,9 @@ class DecoderCrossAttention(AttentionBase):
class DecoderSelfAttention(AttentionBase):
def __init__(self, head_count: int, embed_count: int):
def __init__(self, head_count: int, embed_count: int, token_count: int):
super().__init__(head_count, embed_count)
token_indices = torch.arange(256)
token_indices = torch.arange(token_count)
if torch.cuda.is_available(): token_indices = token_indices.cuda()
self.token_indices = token_indices
@@ -56,7 +56,11 @@ class DecoderLayer(nn.Module):
super().__init__()
self.image_token_count = image_token_count
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
self.self_attn = DecoderSelfAttention(head_count, embed_count)
self.self_attn = DecoderSelfAttention(
head_count,
embed_count,
image_token_count
)
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
@@ -150,7 +154,7 @@ class DalleBartDecoder(nn.Module):
attention_state: FloatTensor,
prev_tokens: LongTensor,
token_index: LongTensor
) -> Tuple[LongTensor, FloatTensor]:
) -> Tuple[FloatTensor, FloatTensor]:
image_count = encoder_state.shape[0] // 2
token_index_batched = token_index[[0] * image_count * 2]
prev_tokens = prev_tokens[list(range(image_count)) * 2]
@@ -158,16 +162,14 @@ class DalleBartDecoder(nn.Module):
decoder_state += self.embed_positions.forward(token_index_batched)
decoder_state = self.layernorm_embedding.forward(decoder_state)
decoder_state = decoder_state[:, None]
attention_states_new = []
for i in range(self.layer_count):
decoder_state, attention_state_layer = self.layers[i].forward(
decoder_state, attention_state[i] = self.layers[i].forward(
decoder_state,
encoder_state,
attention_state[i],
attention_mask,
token_index
)
attention_states_new.append(attention_state_layer)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = self.condition_factor
@@ -182,7 +184,7 @@ class DalleBartDecoder(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[:, [0]])
)
return probs, torch.stack(attention_states_new)
return probs, attention_state
def forward(
@@ -203,10 +205,17 @@ class DalleBartDecoder(nn.Module):
self.embed_count
)
attention_state = torch.zeros(attention_state_shape)
if torch.cuda.is_available(): attention_state = attention_state.cuda()
image_tokens_sequence = torch.full(
(image_count, self.image_token_count),
6965, # black token
dtype=torch.long
)
if torch.cuda.is_available():
attention_state = attention_state.cuda()
image_tokens_sequence = image_tokens_sequence.cuda()
image_tokens = self.start_token[[0] * image_count]
image_tokens_sequence: List[LongTensor] = []
for i in range(self.sample_token_count):
probs, attention_state = self.decode_step(
attention_mask = attention_mask,
@@ -217,6 +226,6 @@ class DalleBartDecoder(nn.Module):
)
image_tokens = torch.multinomial(probs, 1)[:, 0]
image_tokens_sequence += [image_tokens]
image_tokens_sequence[:, i] = image_tokens
return torch.stack(image_tokens_sequence).T
return image_tokens_sequence