From a3a247e6ec2629c73cf96ad22d3cd3c3ccaae5ef Mon Sep 17 00:00:00 2001 From: Haydn Jones Date: Tue, 28 Jun 2022 17:57:44 -0600 Subject: [PATCH] Fixed disabling of gradients in the torch code --- min_dalle/min_dalle_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 690e02c..e8a8563 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -2,7 +2,7 @@ import numpy from typing import Dict from torch import LongTensor, FloatTensor import torch -torch.no_grad() +torch.set_grad_enabled(False) from .models.vqgan_detokenizer import VQGanDetokenizer from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch