save converted detokenizer params
@@ -37,7 +37,8 @@ class DecoderSelfAttentionFlax(AttentionFlax):
             state_index
         )
         batch_count = decoder_state.shape[0]
-        keys, values = attention_state[:batch_count], attention_state[batch_count:]
+        keys = attention_state[:batch_count]
+        values = attention_state[batch_count:]
 
         decoder_state = self.forward(
             keys,
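Note: the hunk above only splits a one-line tuple assignment into two statements. A minimal sketch of the slicing, assuming (as the surrounding code suggests) that attention_state stacks the cached keys and values along its first axis, with the first batch_count rows holding keys; the shapes are illustrative only:

import jax.numpy as jnp

batch_count = 2
# Hypothetical cache: keys and values stacked along axis 0 (2 * batch_count rows).
attention_state = jnp.arange(4 * 3, dtype=jnp.float32).reshape(4, 3)

keys = attention_state[:batch_count]     # first batch_count rows hold the keys
values = attention_state[batch_count:]   # remaining rows hold the values
assert keys.shape == values.shape == (batch_count, 3)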
@@ -120,7 +121,7 @@ class SampleState:
     attention_state: jnp.ndarray
 
 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
-    return a * logits[0, -1] + (1 - a) * logits[1, -1]
+    return (1 - a) * logits[0, -1] + a * logits[1, -1]
 
 def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
     top_logits, _ = lax.top_k(logits, k)
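The super_conditioned change moves the weight a from logits[0, -1] to logits[1, -1], so the blend now favors the second batch row as a grows. A small numeric sketch of the new weighting (the toy logits below, and which row is the text-conditioned one, are assumptions, not stated in the diff):

import jax.numpy as jnp

def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
    # New weighting: row 1 of the last step gets weight a, row 0 gets (1 - a).
    return (1 - a) * logits[0, -1] + a * logits[1, -1]

# Two batch rows, one decoding step, a toy vocabulary of 3 tokens.
logits = jnp.array([[[0.0, 1.0, 2.0]],
                    [[2.0, 1.0, 0.0]]])
print(super_conditioned(logits, 0.75))  # [1.5, 1.0, 0.5]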
@@ -184,7 +184,7 @@ class DalleBartDecoderTorch(nn.Module):
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
-        logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
+        logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
 
         top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
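The Torch decoder applies the same reweighting just before the top-k filtering. If one assumes row 0 holds the unconditioned logits and row 1 the text-conditioned logits (an assumption, the diff does not say which is which), the new expression is algebraically the familiar guidance form, as this small check illustrates:

a = 10.0          # hypothetical condition factor, for illustration only
uncond = 0.1      # a single logit from logits[0, -1], assumed unconditioned row
cond = 0.3        # the matching logit from logits[1, -1], assumed conditioned row

blended = (1 - a) * uncond + a * cond
guided = uncond + a * (cond - uncond)   # same expression, rearranged
assert abs(blended - guided) < 1e-9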