Merge pull request #24 from TheFutureGadgetsLab/main

Fixed disabling of gradients in the torch code
This commit is contained in:
kuprel 2022-06-28 21:10:35 -04:00 committed by GitHub
commit b8c4173181
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@ import numpy
from typing import Dict from typing import Dict
from torch import LongTensor, FloatTensor from torch import LongTensor, FloatTensor
import torch import torch
torch.no_grad() torch.set_grad_enabled(False)
from .models.vqgan_detokenizer import VQGanDetokenizer from .models.vqgan_detokenizer import VQGanDetokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch