fixed bug on replicate
This commit is contained in:
parent
cde7900adc
commit
00d9363172
|
@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module):
|
||||||
)
|
)
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
temperature = settings[0]
|
temperature = settings[[0]]
|
||||||
top_k = settings[1].to(torch.long)
|
top_k = settings[[1]].to(torch.long)
|
||||||
supercondition_factor = settings[2]
|
supercondition_factor = settings[[2]]
|
||||||
logits = logits[:, -1, : 2 ** 14]
|
logits = logits[:, -1, : 2 ** 14]
|
||||||
logits: FloatTensor = (
|
logits: FloatTensor = (
|
||||||
logits[:image_count] * (1 - supercondition_factor) +
|
logits[:image_count] * (1 - supercondition_factor) +
|
||||||
logits[image_count:] * supercondition_factor
|
logits[image_count:] * supercondition_factor
|
||||||
)
|
)
|
||||||
logits_sorted, _ = logits.sort(descending=True)
|
logits_sorted, _ = logits.sort(descending=True)
|
||||||
is_kept = logits >= logits_sorted[:, top_k: top_k + 1]
|
is_kept = logits >= logits_sorted[:, top_k - 1]
|
||||||
logits -= logits_sorted[:, [0]]
|
logits -= logits_sorted[:, [0]]
|
||||||
logits /= temperature
|
logits /= temperature
|
||||||
logits.exp_()
|
logits.exp_()
|
||||||
|
|
|
@ -11,7 +11,8 @@ class ReplicatePredictor(BasePredictor):
|
||||||
self.model = MinDalle(
|
self.model = MinDalle(
|
||||||
is_mega=True,
|
is_mega=True,
|
||||||
is_reusable=True,
|
is_reusable=True,
|
||||||
dtype=torch.float32
|
dtype=torch.float32,
|
||||||
|
device='cuda'
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user