Faster inference with CUDA/cuDNN backend flags

main
Brett Kuprel 2 years ago
parent 703bfb231d
commit dba3f11b3f
  1. 2
      README.md
  2. 8
      cog.yaml
  3. 12
      min_dalle/min_dalle.py
  4. 3
      min_dalle/models/dalle_bart_decoder.py
  5. 1
      min_dalle/models/dalle_bart_encoder.py
  6. 1
      min_dalle/models/vqgan_detokenizer.py
  7. 4
      replicate_predictor.py
  8. 2
      setup.py

2
README.md vendored

@ -11,7 +11,7 @@ This is a fast, minimal port of Boris Dayma's [DALL·E Mega](https://github.com/
To generate a 4x4 grid of DALL·E Mega images it takes: To generate a 4x4 grid of DALL·E Mega images it takes:
- 89 sec with a T4 in Colab - 89 sec with a T4 in Colab
- 48 sec with a P100 in Colab - 48 sec with a P100 in Colab
- 14 sec with an A100 on Replicate - 13 sec with an A100 on Replicate
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax). The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).

8
cog.yaml vendored

@ -1,13 +1,13 @@
build: build:
cuda: "11.4" cuda: "11.5.1"
gpu: true gpu: true
python_version: "3.10" python_version: "3.10"
system_packages: system_packages:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.3.5" - "min-dalle==0.3.7"
run: run:
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
predict: "replicate_predictor.py:ReplicatePredictor" predict: "replicate_predictor.py:ReplicatePredictor"

@ -4,15 +4,21 @@ import numpy
from torch import LongTensor, FloatTensor from torch import LongTensor, FloatTensor
from math import sqrt from math import sqrt
import torch import torch
import torch.backends.cudnn, torch.backends.cuda
import json import json
import requests import requests
from typing import Iterator from typing import Iterator
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .text_tokenizer import TextTokenizer from .text_tokenizer import TextTokenizer
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'

@ -1,8 +1,6 @@
from typing import Tuple, List from typing import Tuple, List
import torch import torch
from torch import nn, LongTensor, FloatTensor, BoolTensor from torch import nn, LongTensor, FloatTensor, BoolTensor
torch.set_grad_enabled(False)
from .dalle_bart_encoder import GLU, AttentionBase from .dalle_bart_encoder import GLU, AttentionBase
IMAGE_TOKEN_COUNT = 256 IMAGE_TOKEN_COUNT = 256
@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
self.zero_prob, self.zero_prob,
torch.exp(logits - top_logits[:, [0]]) torch.exp(logits - top_logits[:, [0]])
) )
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
return probs, attention_state return probs, attention_state

@ -1,7 +1,6 @@
from typing import List from typing import List
import torch import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False)
class GLU(nn.Module): class GLU(nn.Module):

@ -1,7 +1,6 @@
import torch import torch
from torch import FloatTensor, LongTensor from torch import FloatTensor, LongTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False)
class ResnetBlock(Module): class ResnetBlock(Module):

@ -1,9 +1,11 @@
from min_dalle import MinDalle from min_dalle import MinDalle
import tempfile import tempfile
import torch import torch, torch.backends.cudnn
from typing import Iterator from typing import Iterator
from cog import BasePredictor, Path, Input from cog import BasePredictor, Path, Input
torch.backends.cudnn.deterministic = False
class ReplicatePredictor(BasePredictor): class ReplicatePredictor(BasePredictor):
def setup(self): def setup(self):

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.3.5', version='0.3.7',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',

Loading…
Cancel
Save