|
|
|
@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
) |
|
|
|
|
decoder_state = self.final_ln(decoder_state) |
|
|
|
|
logits = self.lm_head(decoder_state) |
|
|
|
|
temperature = settings[0] |
|
|
|
|
top_k = settings[1].to(torch.long) |
|
|
|
|
supercondition_factor = settings[2] |
|
|
|
|
temperature = settings[[0]] |
|
|
|
|
top_k = settings[[1]].to(torch.long) |
|
|
|
|
supercondition_factor = settings[[2]] |
|
|
|
|
logits = logits[:, -1, : 2 ** 14] |
|
|
|
|
logits: FloatTensor = ( |
|
|
|
|
logits[:image_count] * (1 - supercondition_factor) + |
|
|
|
|
logits[image_count:] * supercondition_factor |
|
|
|
|
) |
|
|
|
|
logits_sorted, _ = logits.sort(descending=True) |
|
|
|
|
is_kept = logits >= logits_sorted[:, top_k: top_k + 1] |
|
|
|
|
is_kept = logits >= logits_sorted[:, top_k - 1] |
|
|
|
|
logits -= logits_sorted[:, [0]] |
|
|
|
|
logits /= temperature |
|
|
|
|
logits.exp_() |
|
|
|
|