faster inference with cuda/cudnn backends flags

Brett Kuprel
2022-07-09 06:48:51 -04:00
parent 703bfb231d
commit dba3f11b3f
8 changed files with 19 additions and 14 deletions


@@ -4,15 +4,21 @@ import numpy
 from torch import LongTensor, FloatTensor
 from math import sqrt
 import torch
+import torch.backends.cudnn, torch.backends.cuda
 import json
 import requests
 from typing import Iterator
-torch.set_grad_enabled(False)
-torch.set_num_threads(os.cpu_count())
 from .text_tokenizer import TextTokenizer
 from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
+torch.set_grad_enabled(False)
+torch.set_num_threads(os.cpu_count())
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
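Taken together, the flags added in this hunk trade bit-exact float32 math for throughput: allow_tf32 lets Ampere-and-newer GPUs route float32 matmuls and cuDNN convolutions through TF32 tensor cores, cudnn.benchmark has cuDNN autotune its convolution algorithm for the input shapes it actually sees, and allow_fp16_reduced_precision_reduction permits reduced-precision accumulation in fp16 matmuls. A minimal sketch, not part of this commit and assuming a CUDA-capable GPU, of how one might measure the TF32 effect:

import time
import torch

def time_matmul(allow_tf32: bool, n: int = 4096, iters: int = 10) -> float:
    # Toggle the same flag the commit sets, then time an n x n matmul.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    a = torch.randn(n, n, device='cuda')
    b = torch.randn(n, n, device='cuda')
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

if torch.cuda.is_available():
    print('fp32 matmul:', time_matmul(False))
    print('tf32 matmul:', time_matmul(True))

On TF32-capable hardware the second timing is typically substantially faster, at the cost of a shorter mantissa in the intermediate products.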


@@ -1,8 +1,6 @@
 from typing import Tuple, List
 import torch
 from torch import nn, LongTensor, FloatTensor, BoolTensor
-torch.set_grad_enabled(False)
-
 from .dalle_bart_encoder import GLU, AttentionBase
 IMAGE_TOKEN_COUNT = 256
@@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
             self.zero_prob,
             torch.exp(logits - top_logits[:, [0]])
         )
+        probs[:, 2 ** 14:] = 0  # vqgan vocab_count is only 2 ** 14
         return probs, attention_state
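The one-line addition above guards sampling: the decoder's image-token logits are wider than the VQGAN codebook, so any probability mass past index 2 ** 14 must be zeroed or sampling could draw a token id the detokenizer has no code for. A minimal sketch with hypothetical shapes (the stand-in padded width is not the repo's actual logit size):

import torch

vocab_count = 2 ** 14                 # VQGAN codebook size (16384)
padded_width = vocab_count + 512      # hypothetical padded logit width
probs = torch.rand(4, padded_width)   # stand-in for the decoder's probs
probs[:, vocab_count:] = 0            # same masking as the commit
tokens = torch.multinomial(probs, num_samples=1)
assert int(tokens.max()) < vocab_count  # sampled ids are always valid codes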


@@ -1,7 +1,6 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
-torch.set_grad_enabled(False)
 class GLU(nn.Module):


@@ -1,7 +1,6 @@
 import torch
 from torch import FloatTensor, LongTensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
-torch.set_grad_enabled(False)
 class ResnetBlock(Module):
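The model-file hunks all make the same deletion: each module had its own torch.set_grad_enabled(False) at import time, which is redundant because that call flips a single process-wide autograd switch, and the first hunk now makes it once in min_dalle.py next to the other global settings. A minimal sketch of why one call is enough:

import torch

torch.set_grad_enabled(False)        # one process-wide switch

x = torch.ones(2, requires_grad=True)
y = x * 2
print(y.requires_grad)               # False: no autograd graph is recorded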