Fixed disabling of gradients in the torch code
This commit is contained in:
parent
9d6b6dcc92
commit
a3a247e6ec
|
@ -2,7 +2,7 @@ import numpy
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from torch import LongTensor, FloatTensor
|
from torch import LongTensor, FloatTensor
|
||||||
import torch
|
import torch
|
||||||
torch.no_grad()
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
|
|
Loading…
Reference in New Issue
Block a user