cuda in detokenizer from previous commit broke colab flax model, fixed
This commit is contained in:
parent
d846cab1b6
commit
764b0bc685
|
@ -110,7 +110,7 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
|
||||||
params = load_vqgan_torch_params(model_path)
|
params = load_vqgan_torch_params(model_path)
|
||||||
detokenizer = VQGanDetokenizer()
|
detokenizer = VQGanDetokenizer()
|
||||||
detokenizer.load_state_dict(params)
|
detokenizer.load_state_dict(params)
|
||||||
if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
|
# if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
|
||||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
del detokenizer, params
|
del detokenizer, params
|
||||||
return image.to('cpu').detach().numpy()
|
return image.to('cpu').detach().numpy()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user