forgot missing 2**
This commit is contained in:
@@ -164,7 +164,7 @@ class DalleBartDecoder(nn.Module):
|
||||
)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
logits = self.lm_head(decoder_state)
|
||||
a = log2_supercondition_factor
|
||||
a = 2 ** log2_supercondition_factor
|
||||
logits: FloatTensor = (
|
||||
logits[:image_count, -1] * (1 - a) +
|
||||
logits[image_count:, -1] * a
|
||||
|
Reference in New Issue
Block a user