From dba3f11b3f2b5904f9b666c10fb999996824c791 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Sat, 9 Jul 2022 06:48:51 -0400 Subject: [PATCH] faster inference with cuda/cudnn backends flags --- README.md | 2 +- cog.yaml | 8 ++++---- min_dalle/min_dalle.py | 12 +++++++++--- min_dalle/models/dalle_bart_decoder.py | 3 +-- min_dalle/models/dalle_bart_encoder.py | 1 - min_dalle/models/vqgan_detokenizer.py | 1 - replicate_predictor.py | 4 +++- setup.py | 2 +- 8 files changed, 19 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 2ae0c83..0b973c1 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This is a fast, minimal port of Boris Dayma's [DALL·E Mega](https://github.com/ To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a T4 in Colab - 48 sec with a P100 in Colab -- 14 sec with an A100 on Replicate +- 13 sec with an A100 on Replicate The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax). diff --git a/cog.yaml b/cog.yaml index 9d2e90b..bed81dc 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,13 +1,13 @@ build: - cuda: "11.4" + cuda: "11.5.1" gpu: true python_version: "3.10" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.3.5" + - "min-dalle==0.3.7" run: - - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html + - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html -predict: "replicate_predictor.py:ReplicatePredictor" +predict: "replicate_predictor.py:ReplicatePredictor" \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index fedbf3a..9b02bba 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -4,15 +4,21 @@ import numpy from torch import LongTensor, FloatTensor from math import sqrt import torch +import torch.backends.cudnn, torch.backends.cuda import json import requests from typing import Iterator -torch.set_grad_enabled(False) -torch.set_num_threads(os.cpu_count()) - from .text_tokenizer import TextTokenizer from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer +torch.set_grad_enabled(False) +torch.set_num_threads(os.cpu_count()) +torch.backends.cudnn.enabled = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cudnn.benchmark = True +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index 917f354..9e75e4f 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -1,8 +1,6 @@ from typing import Tuple, List import torch from torch import nn, LongTensor, FloatTensor, BoolTensor -torch.set_grad_enabled(False) - from .dalle_bart_encoder import GLU, AttentionBase IMAGE_TOKEN_COUNT = 256 @@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module): self.zero_prob, torch.exp(logits - top_logits[:, [0]]) ) + probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14 return probs, attention_state diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py index 7081fa3..cb93389 100644 --- a/min_dalle/models/dalle_bart_encoder.py +++ b/min_dalle/models/dalle_bart_encoder.py @@ -1,7 +1,6 @@ from typing import List import torch from torch import nn, BoolTensor, FloatTensor, LongTensor -torch.set_grad_enabled(False) class GLU(nn.Module): diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py index c3ffebb..9c9ee48 100644 --- a/min_dalle/models/vqgan_detokenizer.py +++ b/min_dalle/models/vqgan_detokenizer.py @@ -1,7 +1,6 @@ import torch from torch import FloatTensor, LongTensor from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding -torch.set_grad_enabled(False) class ResnetBlock(Module): diff --git a/replicate_predictor.py b/replicate_predictor.py index 5177c5d..39018ec 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -1,9 +1,11 @@ from min_dalle import MinDalle import tempfile -import torch +import torch, torch.backends.cudnn from typing import Iterator from cog import BasePredictor, Path, Input +torch.backends.cudnn.deterministic = False + class ReplicatePredictor(BasePredictor): def setup(self): diff --git a/setup.py b/setup.py index 0c8b658..b408af3 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.3.5', + version='0.3.7', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',