fixed bug on replicate

This commit is contained in:
Brett Kuprel
2022-07-15 21:21:16 -04:00
parent cde7900adc
commit 00d9363172
2 changed files with 6 additions and 5 deletions

View File

@@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module):
)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
temperature = settings[0]
top_k = settings[1].to(torch.long)
supercondition_factor = settings[2]
temperature = settings[[0]]
top_k = settings[[1]].to(torch.long)
supercondition_factor = settings[[2]]
logits = logits[:, -1, : 2 ** 14]
logits: FloatTensor = (
logits[:image_count] * (1 - supercondition_factor) +
logits[image_count:] * supercondition_factor
)
logits_sorted, _ = logits.sort(descending=True)
is_kept = logits >= logits_sorted[:, top_k: top_k + 1]
is_kept = logits >= logits_sorted[:, top_k - 1]
logits -= logits_sorted[:, [0]]
logits /= temperature
logits.exp_()