diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index c47ad39..9f12858 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -142,6 +142,7 @@ class MinDalle: params = torch.load(self.detoker_params_path) self.detokenizer.load_state_dict(params) del params + torch.cuda.empty_cache() if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() @@ -175,6 +176,7 @@ class MinDalle: encoder_state ) if not self.is_reusable: del self.decoder + torch.cuda.empty_cache() return image_tokens