fixed bug on replicate
This commit is contained in:
parent
cde7900adc
commit
00d9363172
|
@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module):
|
|||
)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
logits = self.lm_head(decoder_state)
|
||||
temperature = settings[0]
|
||||
top_k = settings[1].to(torch.long)
|
||||
supercondition_factor = settings[2]
|
||||
temperature = settings[[0]]
|
||||
top_k = settings[[1]].to(torch.long)
|
||||
supercondition_factor = settings[[2]]
|
||||
logits = logits[:, -1, : 2 ** 14]
|
||||
logits: FloatTensor = (
|
||||
logits[:image_count] * (1 - supercondition_factor) +
|
||||
logits[image_count:] * supercondition_factor
|
||||
)
|
||||
logits_sorted, _ = logits.sort(descending=True)
|
||||
is_kept = logits >= logits_sorted[:, top_k: top_k + 1]
|
||||
is_kept = logits >= logits_sorted[:, top_k - 1]
|
||||
logits -= logits_sorted[:, [0]]
|
||||
logits /= temperature
|
||||
logits.exp_()
|
||||
|
|
|
@ -11,7 +11,8 @@ class ReplicatePredictor(BasePredictor):
|
|||
self.model = MinDalle(
|
||||
is_mega=True,
|
||||
is_reusable=True,
|
||||
dtype=torch.float32
|
||||
dtype=torch.float32,
|
||||
device='cuda'
|
||||
)
|
||||
|
||||
def predict(
|
||||
|
|
Loading…
Reference in New Issue
Block a user