Faster inference with CUDA/cuDNN backend flags
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
from typing import Tuple, List
|
||||
import torch
|
||||
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
from .dalle_bart_encoder import GLU, AttentionBase
|
||||
|
||||
IMAGE_TOKEN_COUNT = 256
|
||||
@@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
|
||||
self.zero_prob,
|
||||
torch.exp(logits - top_logits[:, [0]])
|
||||
)
|
||||
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
|
||||
return probs, attention_state
|
||||
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from torch import FloatTensor, LongTensor
|
||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class ResnetBlock(Module):
|
||||
|
Reference in New Issue
Block a user