diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index f2ee357..9ab707c 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -161,16 +161,16 @@ class DalleBartDecoder(nn.Module): ) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) - temperature = settings[0] - top_k = settings[1].to(torch.long) - supercondition_factor = settings[2] + temperature = settings[[0]] + top_k = settings[[1]].to(torch.long) + supercondition_factor = settings[[2]] logits = logits[:, -1, : 2 ** 14] logits: FloatTensor = ( logits[:image_count] * (1 - supercondition_factor) + logits[image_count:] * supercondition_factor ) logits_sorted, _ = logits.sort(descending=True) - is_kept = logits >= logits_sorted[:, top_k: top_k + 1] + is_kept = logits >= logits_sorted[:, top_k - 1] logits -= logits_sorted[:, [0]] logits /= temperature logits.exp_() diff --git a/replicate_predictor.py b/replicate_predictor.py index 8e702e4..3bac013 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -11,7 +11,8 @@ class ReplicatePredictor(BasePredictor): self.model = MinDalle( is_mega=True, is_reusable=True, - dtype=torch.float32 + dtype=torch.float32, + device='cuda' ) def predict(