From 764b0bc685ad73781f7a0ca0f5b11c58765c4539 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 28 Jun 2022 21:36:48 -0400 Subject: [PATCH] cuda in detokenizer from previous commit broke colab flax model, fixed --- min_dalle/min_dalle_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 5a6a39d..215181b 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -110,7 +110,7 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray: params = load_vqgan_torch_params(model_path) detokenizer = VQGanDetokenizer() detokenizer.load_state_dict(params) - if torch.cuda.is_available(): detokenizer = detokenizer.cuda() + # if torch.cuda.is_available(): detokenizer = detokenizer.cuda() image = detokenizer.forward(image_tokens).to(torch.uint8) del detokenizer, params return image.to('cpu').detach().numpy()