faster inference with cuda/cudnn backends flags
This commit is contained in:
parent
703bfb231d
commit
dba3f11b3f
2
README.md
vendored
2
README.md
vendored
|
@ -11,7 +11,7 @@ This is a fast, minimal port of Boris Dayma's [DALL·E Mega](https://github.com/
|
||||||
To generate a 4x4 grid of DALL·E Mega images it takes:
|
To generate a 4x4 grid of DALL·E Mega images it takes:
|
||||||
- 89 sec with a T4 in Colab
|
- 89 sec with a T4 in Colab
|
||||||
- 48 sec with a P100 in Colab
|
- 48 sec with a P100 in Colab
|
||||||
- 14 sec with an A100 on Replicate
|
- 13 sec with an A100 on Replicate
|
||||||
|
|
||||||
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
|
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
|
||||||
|
|
||||||
|
|
6
cog.yaml
vendored
6
cog.yaml
vendored
|
@ -1,13 +1,13 @@
|
||||||
build:
|
build:
|
||||||
cuda: "11.4"
|
cuda: "11.5.1"
|
||||||
gpu: true
|
gpu: true
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
system_packages:
|
system_packages:
|
||||||
- "libgl1-mesa-glx"
|
- "libgl1-mesa-glx"
|
||||||
- "libglib2.0-0"
|
- "libglib2.0-0"
|
||||||
python_packages:
|
python_packages:
|
||||||
- "min-dalle==0.3.5"
|
- "min-dalle==0.3.7"
|
||||||
run:
|
run:
|
||||||
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
predict: "replicate_predictor.py:ReplicatePredictor"
|
predict: "replicate_predictor.py:ReplicatePredictor"
|
|
@ -4,15 +4,21 @@ import numpy
|
||||||
from torch import LongTensor, FloatTensor
|
from torch import LongTensor, FloatTensor
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
import torch
|
import torch
|
||||||
|
import torch.backends.cudnn, torch.backends.cuda
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
torch.set_num_threads(os.cpu_count())
|
|
||||||
|
|
||||||
from .text_tokenizer import TextTokenizer
|
from .text_tokenizer import TextTokenizer
|
||||||
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
||||||
|
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
||||||
|
|
||||||
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
|
|
||||||
from .dalle_bart_encoder import GLU, AttentionBase
|
from .dalle_bart_encoder import GLU, AttentionBase
|
||||||
|
|
||||||
IMAGE_TOKEN_COUNT = 256
|
IMAGE_TOKEN_COUNT = 256
|
||||||
|
@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
|
||||||
self.zero_prob,
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[:, [0]])
|
torch.exp(logits - top_logits[:, [0]])
|
||||||
)
|
)
|
||||||
|
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
|
||||||
return probs, attention_state
|
return probs, attention_state
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
|
|
||||||
|
|
||||||
class GLU(nn.Module):
|
class GLU(nn.Module):
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import FloatTensor, LongTensor
|
from torch import FloatTensor, LongTensor
|
||||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock(Module):
|
class ResnetBlock(Module):
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
from min_dalle import MinDalle
|
from min_dalle import MinDalle
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch
|
import torch, torch.backends.cudnn
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from cog import BasePredictor, Path, Input
|
from cog import BasePredictor, Path, Input
|
||||||
|
|
||||||
|
torch.backends.cudnn.deterministic = False
|
||||||
|
|
||||||
|
|
||||||
class ReplicatePredictor(BasePredictor):
|
class ReplicatePredictor(BasePredictor):
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -5,7 +5,7 @@ setuptools.setup(
|
||||||
name='min-dalle',
|
name='min-dalle',
|
||||||
description = 'min(DALL·E)',
|
description = 'min(DALL·E)',
|
||||||
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
||||||
version='0.3.5',
|
version='0.3.7',
|
||||||
author='Brett Kuprel',
|
author='Brett Kuprel',
|
||||||
author_email='brkuprel@gmail.com',
|
author_email='brkuprel@gmail.com',
|
||||||
url='https://github.com/kuprel/min-dalle',
|
url='https://github.com/kuprel/min-dalle',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user