Fixed disabling of gradients in the torch code

main
Haydn Jones 2 years ago
parent 9d6b6dcc92
commit a3a247e6ec
  1. 2
      min_dalle/min_dalle_torch.py

@ -2,7 +2,7 @@ import numpy
from typing import Dict
from torch import LongTensor, FloatTensor
import torch
torch.no_grad()
torch.set_grad_enabled(False)
from .models.vqgan_detokenizer import VQGanDetokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch

Loading…
Cancel
Save