forgot missing 2**

This commit is contained in:
Brett Kuprel
2022-07-04 23:29:48 -04:00
parent 976f438879
commit ccdcbc7d46
5 changed files with 5 additions and 5 deletions

View File

@@ -164,7 +164,7 @@ class DalleBartDecoder(nn.Module):
)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = log2_supercondition_factor
a = 2 ** log2_supercondition_factor
logits: FloatTensor = (
logits[:image_count, -1] * (1 - a) +
logits[image_count:, -1] * a