diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 690e02c..e8a8563 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -2,7 +2,7 @@ import numpy from typing import Dict from torch import LongTensor, FloatTensor import torch -torch.no_grad() +torch.set_grad_enabled(False) from .models.vqgan_detokenizer import VQGanDetokenizer from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch