diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py
index bce3bff..9957f2b 100644
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder_torch.py
@@ -16,12 +16,6 @@ class DecoderCrossAttentionTorch(AttentionTorch):
         keys = self.k_proj.forward(encoder_state)
         values = self.v_proj.forward(encoder_state)
         queries = self.q_proj.forward(decoder_state)
-        query_shape = queries.shape[:2] + (self.head_count, -1)
-        key_value_shape = keys.shape[:2] + (self.head_count, -1)
-        keys = keys.reshape(key_value_shape)
-        values = values.reshape(key_value_shape)
-        queries = queries.reshape(query_shape)
-        queries /= queries.shape[-1] ** 0.5
         return super().forward(keys, values, queries, attention_mask)


@@ -34,16 +28,14 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
         batch_count = decoder_state.shape[0]
-        shape = (batch_count, 1) + keys_values.shape[2:]
-        keys = self.k_proj.forward(decoder_state).view(shape)
-        values = self.v_proj.forward(decoder_state).view(shape)
+        keys = self.k_proj.forward(decoder_state)
+        values = self.v_proj.forward(decoder_state)
+        queries = self.q_proj.forward(decoder_state)
         keys_values = torch.where(
-            token_mask[None, :, None, None],
+            token_mask[None, :, None],
             torch.cat([keys, values]),
             keys_values
         )
-        queries = self.q_proj.forward(decoder_state).reshape(shape)
-        queries /= queries.shape[-1] ** 0.5
         keys, values = keys_values[:batch_count], keys_values[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
         return decoder_state, keys_values
@@ -150,8 +142,7 @@ class DalleBartDecoderTorch(nn.Module):
         self.keys_values_state_shape = (
             layer_count * 2 * batch_count,
             image_token_count,
-            attention_head_count,
-            embed_count // attention_head_count
+            embed_count
         )
         self.zero_prob = torch.zeros([1])
         self.token_indices = torch.arange(self.sample_token_count)
@@ -188,7 +179,6 @@ class DalleBartDecoderTorch(nn.Module):
                 token_index[:1]
             )
             keys_values.append(keys_values_layer)
-        keys_values = torch.cat(keys_values, dim=0)
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
@@ -200,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
             self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
-        return probs, keys_values
+        return probs, torch.cat(keys_values)


     def forward(
diff --git a/min_dalle/models/dalle_bart_encoder_torch.py b/min_dalle/models/dalle_bart_encoder_torch.py
index afd6295..296cdec 100644
--- a/min_dalle/models/dalle_bart_encoder_torch.py
+++ b/min_dalle/models/dalle_bart_encoder_torch.py
@@ -44,6 +44,11 @@ class AttentionTorch(nn.Module):
         queries: FloatTensor,
         attention_mask: BoolTensor
     ) -> FloatTensor:
+        keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
+        values = values.reshape(values.shape[:2] + (self.head_count, -1))
+        queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
+        queries /= queries.shape[-1] ** 0.5
+
         attention_bias = torch.where(
             attention_mask,
             self.one * 0,
@@ -73,11 +78,9 @@ class EncoderSelfAttentionTorch(AttentionTorch):
         encoder_state: FloatTensor,
         attention_mask: BoolTensor
     ) -> FloatTensor:
-        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
-        keys = self.k_proj.forward(encoder_state).reshape(shape_split)
-        values = self.v_proj.forward(encoder_state).reshape(shape_split)
-        queries = self.q_proj.forward(encoder_state).reshape(shape_split)
-        queries /= queries.shape[-1] ** 0.5
+        keys = self.k_proj.forward(encoder_state)
+        values = self.v_proj.forward(encoder_state)
+        queries = self.q_proj.forward(encoder_state)
         return super().forward(keys, values, queries, attention_mask)
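
Taken together, the hunks above absorb the per-head reshape and the 1/sqrt(head_dim) query scaling into the shared AttentionTorch.forward, so DecoderCrossAttentionTorch, DecoderSelfAttentionTorch, and EncoderSelfAttentionTorch now pass flat (batch, tokens, embed_count) projections, and the decoder's keys_values cache is stored with heads unsplit. Below is a minimal self-contained sketch of that pattern, assuming head_count divides embed_count and omitting the attention_mask/bias handling for brevity; AttentionSketch, SelfAttentionSketch, and the einsum spellings are illustrative stand-ins, not the repo's exact code.

    import torch
    from torch import nn, FloatTensor


    class AttentionSketch(nn.Module):
        def __init__(self, head_count: int, embed_count: int):
            super().__init__()
            self.head_count = head_count
            self.k_proj = nn.Linear(embed_count, embed_count, bias=False)
            self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
            self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
            self.out_proj = nn.Linear(embed_count, embed_count, bias=False)

        def forward(
            self,
            keys: FloatTensor,      # (batch, key_tokens, embed_count), heads unsplit
            values: FloatTensor,    # (batch, key_tokens, embed_count), heads unsplit
            queries: FloatTensor    # (batch, query_tokens, embed_count), heads unsplit
        ) -> FloatTensor:
            # The refactor's core idea: split heads and scale queries once,
            # here in the base class, instead of repeating it in every subclass.
            keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
            values = values.reshape(values.shape[:2] + (self.head_count, -1))
            queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
            queries = queries / queries.shape[-1] ** 0.5
            weights = torch.einsum('bqhd,bkhd->bhqk', queries, keys).softmax(-1)
            hidden = torch.einsum('bhqk,bkhd->bqhd', weights, values)
            hidden = hidden.reshape(hidden.shape[:2] + (-1,))  # merge heads back
            return self.out_proj(hidden)


    class SelfAttentionSketch(AttentionSketch):
        def forward(self, state: FloatTensor) -> FloatTensor:
            # Subclasses are reduced to projections; no reshape or scaling here.
            return super().forward(
                self.k_proj(state), self.v_proj(state), self.q_proj(state)
            )


    attn = SelfAttentionSketch(head_count=4, embed_count=64)
    out = attn(torch.randn(2, 10, 64))  # -> shape (2, 10, 64)

Caching keys and values unsplit is also what lets the decoder drop a broadcast dimension from the token mask (token_mask[None, :, None] instead of token_mask[None, :, None, None]) and flatten the last two entries of keys_values_state_shape, (attention_head_count, embed_count // attention_head_count), into a single embed_count.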