diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index b5c0332..ef081d8 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -101,10 +101,6 @@ def convert_dalle_bart_torch_from_flax_params( k = i.replace(j, 'layers.' + str(l)) P[k] = P[i][l] P.pop(i) - - for i in list(P): - if '_proj' in i: - P[i] = P[i][:, :, None, None] P['embed_tokens.weight'] = P.pop('embed_tokens.embedding') P['embed_positions.weight'] = P.pop('embed_positions.embedding') diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py index e2a79e5..f9f2789 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder_torch.py @@ -15,6 +15,12 @@ class DecoderCrossAttentionTorch(AttentionTorch): keys = self.k_proj.forward(encoder_state) values = self.v_proj.forward(encoder_state) queries = self.q_proj.forward(decoder_state) + query_shape = queries.shape[:2] + (self.head_count, -1) + key_value_shape = keys.shape[:2] + (self.head_count, -1) + keys = keys.reshape(key_value_shape) + values = values.reshape(key_value_shape) + queries = queries.reshape(query_shape) + queries /= queries.shape[-1] ** 0.5 return super().forward(keys, values, queries, attention_mask) @@ -24,20 +30,21 @@ class DecoderSelfAttentionTorch(AttentionTorch): keys_values: FloatTensor, attention_mask: BoolTensor, token_index: LongTensor - ) -> Tuple[FloatTensor, FloatTensor, FloatTensor]: - keys = self.k_proj.forward(decoder_state) - values = self.v_proj.forward(decoder_state) - queries = self.q_proj.forward(decoder_state) - + ) -> Tuple[FloatTensor, FloatTensor]: batch_count = decoder_state.shape[0] - token_count = keys_values.shape[-1] + token_count = keys_values.shape[1] + shape = (batch_count, 1) + keys_values.shape[2:] + keys = self.k_proj.forward(decoder_state).view(shape) + values = self.v_proj.forward(decoder_state).view(shape) + token_mask = torch.arange(token_count) == token_index keys_values = torch.where( - (torch.arange(token_count) == token_index)[None, None, :], - torch.cat([keys, values]).squeeze(2), + token_mask[None, :, None, None], + torch.cat([keys, values]), keys_values ) - keys, values = keys_values[:batch_count, :, None], keys_values[batch_count:, :, None] - + queries = self.q_proj.forward(decoder_state).reshape(shape) + queries /= queries.shape[-1] ** 0.5 + keys, values = keys_values[:batch_count], keys_values[batch_count:] decoder_state = super().forward(keys, values, queries, attention_mask) return decoder_state, keys_values @@ -71,28 +78,23 @@ class DecoderLayerTorch(nn.Module): decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) self_attn_mask = torch.arange(self.image_token_count) < token_index + 1 self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0]) - decoder_state = decoder_state.transpose(1, 2).unsqueeze(2) - # print("decoder_state", decoder_state.shape) decoder_state, keys_values_state = self.self_attn.forward( decoder_state, keys_values_state, self_attn_mask, token_index ) - decoder_state = decoder_state.transpose(1, 3).squeeze(2) decoder_state = self.self_attn_layer_norm.forward(decoder_state) decoder_state = residual + decoder_state # Cross Attention residual = decoder_state decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state) - decoder_state = decoder_state.transpose(1, 2).unsqueeze(2) decoder_state = self.encoder_attn.forward( decoder_state, encoder_state, attention_mask ) - decoder_state = decoder_state.transpose(1, 3).squeeze(2) decoder_state = self.encoder_attn_layer_norm.forward(decoder_state) decoder_state = residual + decoder_state @@ -140,9 +142,10 @@ class DalleBartDecoderTorch(nn.Module): self.final_ln = nn.LayerNorm(embed_count) self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False) self.keys_values_state_shape = ( - layer_count * 2 * batch_count, - embed_count, - image_token_count + layer_count * 2 * batch_count, + image_token_count, + attention_head_count, + embed_count // attention_head_count ) @@ -159,7 +162,7 @@ class DalleBartDecoderTorch(nn.Module): decoder_state = self.embed_tokens.forward(prev_token) decoder_state += self.embed_positions.forward(token_index) decoder_state = self.layernorm_embedding.forward(decoder_state) - decoder_state = decoder_state[:, None] # (batch_count, 1, embed_count) + decoder_state = decoder_state[:, None] keys_values = [] for i, layer in enumerate(self.layers): j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count @@ -172,8 +175,8 @@ class DalleBartDecoderTorch(nn.Module): ) keys_values.append(keys_values_layer) keys_values = torch.cat(keys_values, dim=0) - decoder_state = self.final_ln(decoder_state) # (batch_count, 1, embed_count) - logits = self.lm_head(decoder_state) # (batch_count, 1, vocab_size) + decoder_state = self.final_ln(decoder_state) + logits = self.lm_head(decoder_state) a = self.condition_factor logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1] @@ -193,7 +196,6 @@ class DalleBartDecoderTorch(nn.Module): image_tokens: List[LongTensor] = [] keys_values_state = torch.zeros(self.keys_values_state_shape) image_token = self.start_token - encoder_state = encoder_state.transpose(1, 2).unsqueeze(2) for i in range(self.sample_token_count): token_index = torch.tensor([i]).to(torch.long) diff --git a/min_dalle/models/dalle_bart_encoder_torch.py b/min_dalle/models/dalle_bart_encoder_torch.py index 9e481f3..001e126 100644 --- a/min_dalle/models/dalle_bart_encoder_torch.py +++ b/min_dalle/models/dalle_bart_encoder_torch.py @@ -26,12 +26,11 @@ class AttentionTorch(nn.Module): super().__init__() self.head_count = head_count self.embed_count = embed_count - self.head_dim = embed_count // head_count - self.k_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) - self.v_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) - self.q_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) - self.out_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) + self.k_proj = nn.Linear(embed_count, embed_count, bias=False) + self.v_proj = nn.Linear(embed_count, embed_count, bias=False) + self.q_proj = nn.Linear(embed_count, embed_count, bias=False) + self.out_proj = nn.Linear(embed_count, embed_count, bias=False) def forward(self, keys: FloatTensor, @@ -39,47 +38,27 @@ class AttentionTorch(nn.Module): queries: FloatTensor, attention_mask: BoolTensor ) -> FloatTensor: - batch_count = keys.shape[0] - - # b(hc)1q -> bqhc - # print(keys.shape, "keys", values.shape, "values", queries.shape, "queries") - keys = keys.transpose(1, 3) - keys = keys.reshape(keys.shape[:2] + (self.head_count, -1)) - - # b(hc)1q -> bchq - shape = (batch_count, self.head_count, self.head_dim, -1) - values = values.reshape(shape) - values = values.transpose(1, 2) - queries = queries.reshape(shape) - queries = queries.transpose(1, 2) - - # print(keys.shape, "keys", values.shape, "values", queries.shape, "queries") - attention_bias = torch.where( attention_mask, - torch.zeros([1, 1]), - torch.ones([1, 1]) * (-torch.inf), + torch.full(attention_mask.shape, 0.0), + torch.full(attention_mask.shape, -torch.inf), ) attention_weights: FloatTensor = torch.einsum( - 'bchq,bkhc->bkhq', - queries / self.head_dim ** 0.5, + 'bqhc,bkhc->bhqk', + queries, keys ) - attention_weights += attention_bias[:, :, None, None] - attention_weights = torch.softmax(attention_weights, 1) - # print(attention_weights.shape, "attention_weights") - hidden_state: FloatTensor = torch.einsum( - "bkhq,bchk->bchq", + attention_weights += attention_bias[:, None, None, :] + attention_weights = torch.softmax(attention_weights, -1) + attention_output: FloatTensor = torch.einsum( + "bhqk,bkhc->bqhc", attention_weights, values ) - # bchq -> b(hc)1q - # print(hidden_state.shape, "hidden_state") - hidden_state = hidden_state.transpose(1, 2) - hidden_state = hidden_state.reshape(batch_count, self.embed_count, 1, -1) - hidden_state = self.out_proj.forward(hidden_state) - # print(hidden_state.shape, "hidden_state") - return hidden_state + shape = attention_output.shape[:2] + (self.embed_count,) + attention_output = attention_output.reshape(shape) + attention_output = self.out_proj.forward(attention_output) + return attention_output class EncoderSelfAttentionTorch(AttentionTorch): @@ -88,11 +67,11 @@ class EncoderSelfAttentionTorch(AttentionTorch): encoder_state: FloatTensor, attention_mask: BoolTensor ) -> FloatTensor: - encoder_state = encoder_state.transpose(1, 2).unsqueeze(2) - # print(encoder_state.shape, "encoder_state") - keys = self.k_proj.forward(encoder_state) - values = self.v_proj.forward(encoder_state) - queries = self.q_proj.forward(encoder_state) + shape_split = encoder_state.shape[:2] + (self.head_count, -1) + keys = self.k_proj.forward(encoder_state).reshape(shape_split) + values = self.v_proj.forward(encoder_state).reshape(shape_split) + queries = self.q_proj.forward(encoder_state).reshape(shape_split) + queries /= queries.shape[-1] ** 0.5 return super().forward(keys, values, queries, attention_mask) @@ -112,7 +91,6 @@ class EncoderLayerTorch(nn.Module): residual = encoder_state encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state) encoder_state = self.self_attn.forward(encoder_state, attention_mask) - encoder_state = encoder_state.transpose(1, 3).squeeze(2) encoder_state = self.self_attn_layer_norm.forward(encoder_state) encoder_state = residual + encoder_state residual = encoder_state