torch.no_grad(), cleanup
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import List, Tuple
|
||||
import torch
|
||||
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
||||
from typing import List, Tuple
|
||||
torch.no_grad()
|
||||
|
||||
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
torch.no_grad()
|
||||
|
||||
|
||||
class GLUTorch(nn.Module):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||
torch.no_grad()
|
||||
|
||||
BATCH_COUNT: int = 1
|
||||
|
||||
|
Reference in New Issue
Block a user