save converted detokenizer params

Brett Kuprel
2022-07-01 10:17:29 -04:00
parent 8b5960b687
commit e4c2be54cb
7 changed files with 35 additions and 32 deletions

@@ -37,7 +37,8 @@ class DecoderSelfAttentionFlax(AttentionFlax):
             state_index
         )
         batch_count = decoder_state.shape[0]
-        keys, values = attention_state[:batch_count], attention_state[batch_count:]
+        keys = attention_state[:batch_count]
+        values = attention_state[batch_count:]
         decoder_state = self.forward(
             keys,
             values,
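
This hunk splits a packed attention cache into its key and value halves instead of unpacking both in one tuple assignment. A minimal standalone sketch of the assumed layout, where keys and values are stacked along the leading (batch) axis so that attention_state[:batch_count] recovers the keys; the shapes below are illustrative, not taken from this file:

import jax.numpy as jnp

# Illustrative shapes only: (batch, sequence, embedding).
batch_count, seq_len, embed_dim = 2, 256, 1024
keys = jnp.zeros((batch_count, seq_len, embed_dim))
values = jnp.ones((batch_count, seq_len, embed_dim))

# Pack both caches into one array along the leading axis.
attention_state = jnp.concatenate([keys, values], axis=0)

# Mirrors the two lines added in this hunk.
keys_out = attention_state[:batch_count]
values_out = attention_state[batch_count:]
assert keys_out.shape == (batch_count, seq_len, embed_dim)
assert values_out.shape == (batch_count, seq_len, embed_dim)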
@@ -120,7 +121,7 @@ class SampleState:
     attention_state: jnp.ndarray

 def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
-    return a * logits[0, -1] + (1 - a) * logits[1, -1]
+    return (1 - a) * logits[0, -1] + a * logits[1, -1]

 def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
     top_logits, _ = lax.top_k(logits, k)
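
The super_conditioned change swaps which term the weight a multiplies: the text-conditioned logits (logits[1, -1]) now get weight a and the unconditioned logits (logits[0, -1]) get 1 - a, so a > 0.5 pulls sampling toward the text prompt. A small self-contained sketch of the mixing step; the toy logits and the value of a are made up for illustration:

import jax.numpy as jnp

def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
    # Row 0: unconditioned logits, row 1: text-conditioned logits;
    # only the final position (-1) is mixed when choosing the next token.
    return (1 - a) * logits[0, -1] + a * logits[1, -1]

# Toy input: 2 sequences, 3 positions, vocabulary of 4 tokens.
logits = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
mixed = super_conditioned(logits, a=0.9)
print(mixed.shape)  # (4,) -- one mixed logit per vocabulary entry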

@@ -184,7 +184,7 @@ class DalleBartDecoderTorch(nn.Module):
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
-        logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
+        logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]

         top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
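
After the mixing step, the Torch decoder keeps only the 50 largest logits before sampling (the topk/torch.where lines above). A rough sketch of that kind of top-k filtering, not the exact code from this file: logits below the 50th-largest value are zeroed out, the rest are exponentiated and renormalized, and a token is drawn from the result:

import torch

def sample_top_k(logits: torch.Tensor, k: int = 50) -> torch.Tensor:
    # Keep probability mass only where a logit is among the k largest.
    top_logits, _ = logits.topk(k, dim=-1)
    threshold = top_logits[..., -1]
    probs = torch.where(
        logits < threshold,
        torch.zeros_like(logits),
        torch.exp(logits - logits.max())
    )
    probs = probs / probs.sum()
    return torch.multinomial(probs, num_samples=1)

token = sample_top_k(torch.randn(16384))  # 16384 is an illustrative vocab size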