diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 16390fb..525ef9f 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -45,8 +45,8 @@ class MinDalleTorch(MinDalleBase): is_encoder=True ) self.encoder.load_state_dict(params, strict=False) - if torch.cuda.is_available(): self.encoder = self.encoder.cuda() del params + if torch.cuda.is_available(): self.encoder = self.encoder.cuda() def init_decoder(self): @@ -69,8 +69,8 @@ class MinDalleTorch(MinDalleBase): is_encoder=False ) self.decoder.load_state_dict(params, strict=False) - if torch.cuda.is_available(): self.decoder = self.decoder.cuda() del params + if torch.cuda.is_available(): self.decoder = self.decoder.cuda() def init_detokenizer(self):