cuda in detokenizer from previous commit broke colab flax model, fixed

This commit is contained in:
Brett Kuprel 2022-06-28 21:36:48 -04:00
parent d846cab1b6
commit 764b0bc685

View File

@ -110,7 +110,7 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
params = load_vqgan_torch_params(model_path) params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer() detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params) detokenizer.load_state_dict(params)
if torch.cuda.is_available(): detokenizer = detokenizer.cuda() # if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8) image = detokenizer.forward(image_tokens).to(torch.uint8)
del detokenizer, params del detokenizer, params
return image.to('cpu').detach().numpy() return image.to('cpu').detach().numpy()