fixed bug on replicate

main
Brett Kuprel 2 years ago
parent cde7900adc
commit 00d9363172
  1. 8
      min_dalle/models/dalle_bart_decoder.py
  2. 3
      replicate_predictor.py

@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module):
) )
decoder_state = self.final_ln(decoder_state) decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state) logits = self.lm_head(decoder_state)
temperature = settings[0] temperature = settings[[0]]
top_k = settings[1].to(torch.long) top_k = settings[[1]].to(torch.long)
supercondition_factor = settings[2] supercondition_factor = settings[[2]]
logits = logits[:, -1, : 2 ** 14] logits = logits[:, -1, : 2 ** 14]
logits: FloatTensor = ( logits: FloatTensor = (
logits[:image_count] * (1 - supercondition_factor) + logits[:image_count] * (1 - supercondition_factor) +
logits[image_count:] * supercondition_factor logits[image_count:] * supercondition_factor
) )
logits_sorted, _ = logits.sort(descending=True) logits_sorted, _ = logits.sort(descending=True)
is_kept = logits >= logits_sorted[:, top_k: top_k + 1] is_kept = logits >= logits_sorted[:, top_k - 1]
logits -= logits_sorted[:, [0]] logits -= logits_sorted[:, [0]]
logits /= temperature logits /= temperature
logits.exp_() logits.exp_()

@ -11,7 +11,8 @@ class ReplicatePredictor(BasePredictor):
self.model = MinDalle( self.model = MinDalle(
is_mega=True, is_mega=True,
is_reusable=True, is_reusable=True,
dtype=torch.float32 dtype=torch.float32,
device='cuda'
) )
def predict( def predict(

Loading…
Cancel
Save