fixed bug on replicate

This commit is contained in:
Brett Kuprel 2022-07-15 21:21:16 -04:00
parent cde7900adc
commit 00d9363172
2 changed files with 6 additions and 5 deletions

View File

@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module):
)
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
temperature = settings[0]
top_k = settings[1].to(torch.long)
supercondition_factor = settings[2]
temperature = settings[[0]]
top_k = settings[[1]].to(torch.long)
supercondition_factor = settings[[2]]
logits = logits[:, -1, : 2 ** 14]
logits: FloatTensor = (
logits[:image_count] * (1 - supercondition_factor) +
logits[image_count:] * supercondition_factor
)
logits_sorted, _ = logits.sort(descending=True)
is_kept = logits >= logits_sorted[:, top_k: top_k + 1]
is_kept = logits >= logits_sorted[:, top_k - 1]
logits -= logits_sorted[:, [0]]
logits /= temperature
logits.exp_()

View File

@ -11,7 +11,8 @@ class ReplicatePredictor(BasePredictor):
self.model = MinDalle(
is_mega=True,
is_reusable=True,
dtype=torch.float32
dtype=torch.float32,
device='cuda'
)
def predict(