Fixed disabling of gradients in the torch code

This commit is contained in:
Haydn Jones 2022-06-28 17:57:44 -06:00
parent 9d6b6dcc92
commit a3a247e6ec

View File

@ -2,7 +2,7 @@ import numpy
from typing import Dict from typing import Dict
from torch import LongTensor, FloatTensor from torch import LongTensor, FloatTensor
import torch import torch
torch.no_grad() torch.set_grad_enabled(False)
from .models.vqgan_detokenizer import VQGanDetokenizer from .models.vqgan_detokenizer import VQGanDetokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch