diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 5a6a39d..215181b 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -110,7 +110,7 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray: params = load_vqgan_torch_params(model_path) detokenizer = VQGanDetokenizer() detokenizer.load_state_dict(params) - if torch.cuda.is_available(): detokenizer = detokenizer.cuda() + # if torch.cuda.is_available(): detokenizer = detokenizer.cuda() image = detokenizer.forward(image_tokens).to(torch.uint8) del detokenizer, params return image.to('cpu').detach().numpy()