Faster inference with CUDA/cuDNN backend flags

This commit is contained in:
Brett Kuprel
2022-07-09 06:48:51 -04:00
parent 703bfb231d
commit dba3f11b3f
8 changed files with 19 additions and 14 deletions

View File

@@ -1,8 +1,6 @@
from typing import Tuple, List
import torch
from torch import nn, LongTensor, FloatTensor, BoolTensor
torch.set_grad_enabled(False)
from .dalle_bart_encoder import GLU, AttentionBase
IMAGE_TOKEN_COUNT = 256
@@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[:, [0]])
)
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
return probs, attention_state

View File

@@ -1,7 +1,6 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False)
class GLU(nn.Module):

View File

@@ -1,7 +1,6 @@
import torch
from torch import FloatTensor, LongTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False)
class ResnetBlock(Module):