faster inference with cuda/cudnn backends flags

This commit is contained in:
Brett Kuprel 2022-07-09 06:48:51 -04:00
parent 703bfb231d
commit dba3f11b3f
8 changed files with 19 additions and 14 deletions

2
README.md vendored
View File

@ -11,7 +11,7 @@ This is a fast, minimal port of Boris Dayma's [DALL·E Mega](https://github.com/
To generate a 4x4 grid of DALL·E Mega images it takes:
- 89 sec with a T4 in Colab
- 48 sec with a P100 in Colab
- 14 sec with an A100 on Replicate
- 13 sec with an A100 on Replicate
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).

6
cog.yaml vendored
View File

@ -1,13 +1,13 @@
build:
cuda: "11.4"
cuda: "11.5.1"
gpu: true
python_version: "3.10"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "min-dalle==0.3.5"
- "min-dalle==0.3.7"
run:
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
predict: "replicate_predictor.py:ReplicatePredictor"

View File

@ -4,15 +4,21 @@ import numpy
from torch import LongTensor, FloatTensor
from math import sqrt
import torch
import torch.backends.cudnn, torch.backends.cuda
import json
import requests
from typing import Iterator
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .text_tokenizer import TextTokenizer
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'

View File

@ -1,8 +1,6 @@
from typing import Tuple, List
import torch
from torch import nn, LongTensor, FloatTensor, BoolTensor
torch.set_grad_enabled(False)
from .dalle_bart_encoder import GLU, AttentionBase
IMAGE_TOKEN_COUNT = 256
@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[:, [0]])
)
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
return probs, attention_state

View File

@ -1,7 +1,6 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False)
class GLU(nn.Module):

View File

@ -1,7 +1,6 @@
import torch
from torch import FloatTensor, LongTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False)
class ResnetBlock(Module):

View File

@ -1,9 +1,11 @@
from min_dalle import MinDalle
import tempfile
import torch
import torch, torch.backends.cudnn
from typing import Iterator
from cog import BasePredictor, Path, Input
torch.backends.cudnn.deterministic = False
class ReplicatePredictor(BasePredictor):
def setup(self):

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle',
description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.3.5',
version='0.3.7',
author='Brett Kuprel',
author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle',