simplified flax attention and matched torch attention
This commit is contained in:
parent
61cc99c13c
commit
d99828a239
2
README.md
vendored
2
README.md
vendored
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
|
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
|
||||||
|
|
||||||
DALL·E Mega inference with PyTorch takes 7.3 seconds in Colab to generate an avocado armchair
|
It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
|
|
|
@ -13,15 +13,9 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
keys: jnp.ndarray = self.k_proj(encoder_state)
|
keys = self.k_proj(encoder_state)
|
||||||
values: jnp.ndarray = self.v_proj(encoder_state)
|
values = self.v_proj(encoder_state)
|
||||||
queries: jnp.ndarray = self.q_proj(decoder_state)
|
queries = self.q_proj(decoder_state)
|
||||||
query_shape = queries.shape[:2] + (self.head_count, -1)
|
|
||||||
key_value_shape = keys.shape[:2] + (self.head_count, -1)
|
|
||||||
keys = keys.reshape(key_value_shape)
|
|
||||||
values = values.reshape(key_value_shape)
|
|
||||||
queries = queries.reshape(query_shape)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
return self.forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,31 +23,29 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
state_index: tuple
|
state_index: tuple
|
||||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
shape_split = decoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj(decoder_state)
|
||||||
keys_state = lax.dynamic_update_slice(
|
values = self.v_proj(decoder_state)
|
||||||
keys_state,
|
queries = self.q_proj(decoder_state)
|
||||||
self.k_proj(decoder_state).reshape(shape_split),
|
|
||||||
|
attention_state = lax.dynamic_update_slice(
|
||||||
|
attention_state,
|
||||||
|
jnp.concatenate([keys, values]),
|
||||||
state_index
|
state_index
|
||||||
)
|
)
|
||||||
values_state = lax.dynamic_update_slice(
|
batch_count = decoder_state.shape[0]
|
||||||
values_state,
|
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
||||||
self.v_proj(decoder_state).reshape(shape_split),
|
|
||||||
state_index
|
|
||||||
)
|
|
||||||
queries = self.q_proj(decoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
decoder_state = self.forward(
|
decoder_state = self.forward(
|
||||||
keys_state,
|
keys,
|
||||||
values_state,
|
values,
|
||||||
queries,
|
queries,
|
||||||
attention_mask
|
attention_mask
|
||||||
)
|
)
|
||||||
return decoder_state, (keys_state, values_state)
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderLayerFlax(nn.Module):
|
class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
|
@ -82,11 +74,10 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
self,
|
self,
|
||||||
decoder_state: jnp.ndarray,
|
decoder_state: jnp.ndarray,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
residual = decoder_state
|
residual = decoder_state
|
||||||
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
decoder_state = self.pre_self_attn_layer_norm(decoder_state)
|
||||||
|
@ -94,12 +85,11 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
jnp.arange(self.image_token_count) < token_index + 1,
|
jnp.arange(self.image_token_count) < token_index + 1,
|
||||||
(decoder_state.shape[0], 1)
|
(decoder_state.shape[0], 1)
|
||||||
)
|
)
|
||||||
decoder_state, keys_values_state = self.self_attn(
|
decoder_state, attention_state = self.self_attn(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
keys_state,
|
attention_state,
|
||||||
values_state,
|
|
||||||
self_attention_mask,
|
self_attention_mask,
|
||||||
(0, token_index, 0, 0)
|
(0, token_index, 0)
|
||||||
)
|
)
|
||||||
decoder_state = self.self_attn_layer_norm(decoder_state)
|
decoder_state = self.self_attn_layer_norm(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
@ -120,15 +110,14 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||||
decoder_state = self.glu(decoder_state)
|
decoder_state = self.glu(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
|
||||||
return decoder_state, keys_values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class SampleState:
|
class SampleState:
|
||||||
prev_token: jnp.ndarray
|
prev_token: jnp.ndarray
|
||||||
prng_key: jnp.ndarray
|
prng_key: jnp.ndarray
|
||||||
keys_state: jnp.ndarray
|
attention_state: jnp.ndarray
|
||||||
values_state: jnp.ndarray
|
|
||||||
|
|
||||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||||
|
@ -161,8 +150,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
DalleBartDecoderLayerFlax,
|
DalleBartDecoderLayerFlax,
|
||||||
variable_axes = { "params": 0, "cache": 0 },
|
variable_axes = { "params": 0, "cache": 0 },
|
||||||
split_rngs = { "params": True },
|
split_rngs = { "params": True },
|
||||||
in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast),
|
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
||||||
out_axes = (0, 0),
|
out_axes = 0,
|
||||||
length=self.layer_count,
|
length=self.layer_count,
|
||||||
)(
|
)(
|
||||||
self.image_token_count,
|
self.image_token_count,
|
||||||
|
@ -178,28 +167,26 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
keys_state: jnp.ndarray,
|
attention_state: jnp.ndarray,
|
||||||
values_state: jnp.ndarray,
|
|
||||||
attention_mask: jnp.ndarray,
|
attention_mask: jnp.ndarray,
|
||||||
prev_token: int,
|
prev_token: int,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
batch_count = encoder_state.shape[0]
|
batch_count = encoder_state.shape[0]
|
||||||
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
ones = jnp.ones((batch_count, 1), dtype=jnp.int32)
|
||||||
decoder_state = self.embed_tokens(prev_token * ones)
|
decoder_state = self.embed_tokens(prev_token * ones)
|
||||||
decoder_state += self.embed_positions(token_index * ones)
|
decoder_state += self.embed_positions(token_index * ones)
|
||||||
decoder_state = self.layernorm_embedding(decoder_state)
|
decoder_state = self.layernorm_embedding(decoder_state)
|
||||||
decoder_state, (keys_state, values_state) = self.layers(
|
decoder_state, attention_state = self.layers(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
keys_state,
|
attention_state,
|
||||||
values_state,
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_index
|
token_index
|
||||||
)
|
)
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
decoder_state = self.lm_head(decoder_state)
|
decoder_state = self.lm_head(decoder_state)
|
||||||
return decoder_state, keys_state, values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
def sample_image_tokens(
|
def sample_image_tokens(
|
||||||
self,
|
self,
|
||||||
|
@ -213,12 +200,11 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
def sample_next_image_token(
|
def sample_next_image_token(
|
||||||
state: SampleState,
|
state: SampleState,
|
||||||
token_index: int
|
token_index: int
|
||||||
) -> Tuple[SampleState, None]:
|
) -> Tuple[SampleState, jnp.ndarray]:
|
||||||
logits, keys_state, values_state = self.apply(
|
logits, attention_state = self.apply(
|
||||||
{ 'params': params },
|
{ 'params': params },
|
||||||
encoder_state = encoder_state,
|
encoder_state = encoder_state,
|
||||||
keys_state = state.keys_state,
|
attention_state = state.attention_state,
|
||||||
values_state = state.values_state,
|
|
||||||
attention_mask = attention_mask,
|
attention_mask = attention_mask,
|
||||||
prev_token = state.prev_token,
|
prev_token = state.prev_token,
|
||||||
token_index = token_index
|
token_index = token_index
|
||||||
|
@ -233,26 +219,23 @@ class DalleBartDecoderFlax(nn.Module):
|
||||||
state = SampleState(
|
state = SampleState(
|
||||||
prev_token = next_token,
|
prev_token = next_token,
|
||||||
prng_key = prng_key_next,
|
prng_key = prng_key_next,
|
||||||
keys_state = keys_state,
|
attention_state = attention_state
|
||||||
values_state = values_state
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return state, next_token
|
return state, next_token
|
||||||
|
|
||||||
batch_count = encoder_state.shape[0]
|
batch_count = encoder_state.shape[0]
|
||||||
state_shape = (
|
attention_state_shape = (
|
||||||
self.layer_count,
|
self.layer_count,
|
||||||
batch_count,
|
batch_count * 2,
|
||||||
self.image_token_count,
|
self.image_token_count,
|
||||||
self.attention_head_count,
|
self.embed_count
|
||||||
self.embed_count // self.attention_head_count
|
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_state = SampleState(
|
initial_state = SampleState(
|
||||||
prev_token = self.start_token,
|
prev_token = self.start_token,
|
||||||
prng_key = prng_key,
|
prng_key = prng_key,
|
||||||
keys_state = jnp.zeros(state_shape),
|
attention_state = jnp.zeros(attention_state_shape)
|
||||||
values_state = jnp.zeros(state_shape)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_, image_tokens = lax.scan(
|
_, image_tokens = lax.scan(
|
||||||
|
|
|
@ -23,22 +23,22 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
keys_values: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
token_mask: BoolTensor
|
token_mask: BoolTensor
|
||||||
) -> Tuple[FloatTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
batch_count = decoder_state.shape[0]
|
|
||||||
keys = self.k_proj.forward(decoder_state)
|
keys = self.k_proj.forward(decoder_state)
|
||||||
values = self.v_proj.forward(decoder_state)
|
values = self.v_proj.forward(decoder_state)
|
||||||
queries = self.q_proj.forward(decoder_state)
|
queries = self.q_proj.forward(decoder_state)
|
||||||
keys_values = torch.where(
|
attention_state = torch.where(
|
||||||
token_mask[None, :, None],
|
token_mask[None, :, None],
|
||||||
torch.cat([keys, values]),
|
torch.cat([keys, values]),
|
||||||
keys_values
|
attention_state
|
||||||
)
|
)
|
||||||
keys, values = keys_values[:batch_count], keys_values[batch_count:]
|
batch_count = decoder_state.shape[0]
|
||||||
|
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
||||||
decoder_state = super().forward(keys, values, queries, attention_mask)
|
decoder_state = super().forward(keys, values, queries, attention_mask)
|
||||||
return decoder_state, keys_values
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayerTorch(nn.Module):
|
class DecoderLayerTorch(nn.Module):
|
||||||
|
@ -67,7 +67,7 @@ class DecoderLayerTorch(nn.Module):
|
||||||
self,
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
attention_mask: BoolTensor,
|
attention_mask: BoolTensor,
|
||||||
token_index: LongTensor
|
token_index: LongTensor
|
||||||
) -> Tuple[FloatTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
|
@ -77,9 +77,9 @@ class DecoderLayerTorch(nn.Module):
|
||||||
self_attn_mask = self.token_indices < token_index + 1
|
self_attn_mask = self.token_indices < token_index + 1
|
||||||
token_mask = self.token_indices == token_index
|
token_mask = self.token_indices == token_index
|
||||||
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
||||||
decoder_state, keys_values_state = self.self_attn.forward(
|
decoder_state, attention_state = self.self_attn.forward(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
keys_values_state,
|
attention_state,
|
||||||
self_attn_mask,
|
self_attn_mask,
|
||||||
token_mask
|
token_mask
|
||||||
)
|
)
|
||||||
|
@ -102,7 +102,7 @@ class DecoderLayerTorch(nn.Module):
|
||||||
decoder_state = self.glu.forward(decoder_state)
|
decoder_state = self.glu.forward(decoder_state)
|
||||||
decoder_state = residual + decoder_state
|
decoder_state = residual + decoder_state
|
||||||
|
|
||||||
return decoder_state, keys_values_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoderTorch(nn.Module):
|
||||||
|
@ -139,8 +139,9 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||||
self.final_ln = nn.LayerNorm(embed_count)
|
self.final_ln = nn.LayerNorm(embed_count)
|
||||||
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
||||||
self.keys_values_state_shape = (
|
self.attention_state_shape = (
|
||||||
layer_count * 2 * batch_count,
|
layer_count,
|
||||||
|
2 * batch_count,
|
||||||
image_token_count,
|
image_token_count,
|
||||||
embed_count
|
embed_count
|
||||||
)
|
)
|
||||||
|
@ -157,7 +158,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self,
|
self,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
keys_values_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
prev_token_and_index: LongTensor
|
prev_token_and_index: LongTensor
|
||||||
) -> Tuple[LongTensor, FloatTensor]:
|
) -> Tuple[LongTensor, FloatTensor]:
|
||||||
attention_mask = text_tokens.not_equal(1)
|
attention_mask = text_tokens.not_equal(1)
|
||||||
|
@ -168,17 +169,16 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
decoder_state += self.embed_positions.forward(token_index)
|
decoder_state += self.embed_positions.forward(token_index)
|
||||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||||
decoder_state = decoder_state[:, None]
|
decoder_state = decoder_state[:, None]
|
||||||
keys_values = []
|
attention_states_new = []
|
||||||
for i, layer in enumerate(self.layers):
|
for i in range(self.layer_count):
|
||||||
j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
|
decoder_state, attention_state_layer = self.layers[i].forward(
|
||||||
decoder_state, keys_values_layer = layer.forward(
|
|
||||||
decoder_state,
|
decoder_state,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
keys_values_state[j1:j2],
|
attention_state[i],
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_index[:1]
|
token_index[:1]
|
||||||
)
|
)
|
||||||
keys_values.append(keys_values_layer)
|
attention_states_new.append(attention_state_layer)
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
|
@ -190,7 +190,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.zero_prob,
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[0])
|
torch.exp(logits - top_logits[0])
|
||||||
)
|
)
|
||||||
return probs, torch.cat(keys_values)
|
return probs, torch.stack(attention_states_new)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -199,17 +199,17 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
encoder_state: FloatTensor
|
encoder_state: FloatTensor
|
||||||
) -> LongTensor:
|
) -> LongTensor:
|
||||||
image_tokens: List[LongTensor] = []
|
image_tokens: List[LongTensor] = []
|
||||||
keys_values_state = torch.zeros(self.keys_values_state_shape)
|
attention_state = torch.zeros(self.attention_state_shape)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
keys_values_state = keys_values_state.cuda()
|
attention_state = attention_state.cuda()
|
||||||
image_token = self.start_token
|
image_token = self.start_token
|
||||||
|
|
||||||
for i in range(self.sample_token_count):
|
for i in range(self.sample_token_count):
|
||||||
token_index = self.token_indices[i:i+1]
|
token_index = self.token_indices[i:i+1]
|
||||||
probs, keys_values_state = self.decode_step(
|
probs, attention_state = self.decode_step(
|
||||||
text_tokens = text_tokens,
|
text_tokens = text_tokens,
|
||||||
encoder_state = encoder_state,
|
encoder_state = encoder_state,
|
||||||
keys_values_state = keys_values_state,
|
attention_state = attention_state,
|
||||||
prev_token_and_index = torch.cat([image_token, token_index])
|
prev_token_and_index = torch.cat([image_token, token_index])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
|
||||||
queries: jnp.ndarray,
|
queries: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
|
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||||
|
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||||
|
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||||
|
queries /= queries.shape[-1] ** 0.5
|
||||||
attention_bias: jnp.ndarray = lax.select(
|
attention_bias: jnp.ndarray = lax.select(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
jnp.full(attention_mask.shape, 0.0),
|
jnp.full(attention_mask.shape, 0.0),
|
||||||
|
@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
||||||
encoder_state: jnp.ndarray,
|
encoder_state: jnp.ndarray,
|
||||||
attention_mask: jnp.ndarray
|
attention_mask: jnp.ndarray
|
||||||
) -> jnp.ndarray:
|
) -> jnp.ndarray:
|
||||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
keys = self.k_proj(encoder_state)
|
||||||
keys = self.k_proj(encoder_state).reshape(shape_split)
|
values = self.v_proj(encoder_state)
|
||||||
values = self.v_proj(encoder_state).reshape(shape_split)
|
queries = self.q_proj(encoder_state)
|
||||||
queries = self.q_proj(encoder_state).reshape(shape_split)
|
|
||||||
queries /= queries.shape[-1] ** 0.5
|
|
||||||
return self.forward(keys, values, queries, attention_mask)
|
return self.forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user