support bfloat16

This commit is contained in:
Brett Kuprel 2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions

8
README.md vendored
View File

@ -12,7 +12,6 @@ To generate a 4x4 grid of DALL·E Mega images it takes:
- 89 sec with a T4 in Colab
- 48 sec with a P100 in Colab
- 14 sec with an A100 on Replicate
- TBD with an H100 (@NVIDIA?)
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
@ -30,13 +29,14 @@ Load the model parameters once and reuse the model to generate multiple images.
from min_dalle import MinDalle
model = MinDalle(
models_root='./pretrained',
dtype=torch.float32,
is_mega=True,
is_reusable=True,
models_root='./pretrained'
is_reusable=True
)
```
The required models will be downloaded to `models_root` if they are not already there. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top $k$ most probable tokens.
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU, you can set `dtype=torch.bfloat16` to save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
```python
image = model.generate_image(

89
README.rst vendored
View File

@ -1,16 +1,16 @@
min(DALL·E)
===========
|Open In Colab|   |Replicate|   |Join us on Discord|
|Colab|   |Replicate|   |Discord|
This is a fast, minimal implementation of Boris Dayma's `DALL·E
This is a fast, minimal port of Boris Dayma's `DALL·E
Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped
down for inference and converted to PyTorch. The only third party
dependencies are numpy, requests, pillow and torch.
To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
Replicate - TBD with an H100 (@NVIDIA?)
Replicate
The flax model and code for converting it to torch can be found
`here <https://github.com/kuprel/min-dalle-flax>`__.
@ -32,47 +32,58 @@ images.
from min_dalle import MinDalle
model = MinDalle(is_mega=True, models_root='./pretrained')
model = MinDalle(
is_mega=True,
is_reusable=True,
models_root='./pretrained'
)
The required models will be downloaded to ``models_root`` if they are
not already there. Once everything has finished initializing, call
``generate_image`` with some text and a seed as many times as you want.
``generate_image`` with some text as many times as you want. Use a
positive ``seed`` for reproducible results. Higher values for
``log2_supercondition_factor`` result in better agreement with the text
but a narrower variety of generated images. Every image token is sampled
from the top-:math:`k` most probable tokens.
.. code:: python
text = 'Dali painting of WALL·E'
image = model.generate_image(text, seed=0, grid_size=4)
image = model.generate_image(
text='Nuclear explosion broccoli',
seed=-1,
grid_size=4,
log2_k=6,
log2_supercondition_factor=5,
is_verbose=False
)
display(image)
Interactive
~~~~~~~~~~~
If the model is being used interactively (e.g. in a notebook)
``generate_image_stream`` can be used to generate a stream of images as
the model is decoding. The detokenizer adds a slight delay for each
image. Setting ``log2_mid_count`` to 3 results in a total of
``2 ** 3 = 8`` generated images. The only valid values for
``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the
colab.
.. code:: python
text = 'Rusty Iron Man suit found abandoned in the woods being reclaimed by nature'
image = model.generate_image(text, seed=0, grid_size=3)
display(image)
image_stream = model.generate_image_stream(
text='Dali painting of WALL·E',
seed=-1,
grid_size=3,
log2_mid_count=3,
log2_k=6,
log2_supercondition_factor=3,
is_verbose=False
)
.. code:: python
text = 'court sketch of godzilla on trial'
image = model.generate_image(text, seed=6, grid_size=3)
display(image)
.. code:: python
text = 'a funeral at Whole Foods'
image = model.generate_image(text, seed=10, grid_size=3)
display(image)
.. code:: python
text = 'Jesus turning water into wine on Americas Got Talent'
image = model.generate_image(text, seed=2, grid_size=3)
display(image)
.. code:: python
text = 'cctv footage of Yoda robbing a liquor store'
image = model.generate_image(text, seed=0, grid_size=3)
display(image)
for image in image_stream:
display(image)
Command Line
~~~~~~~~~~~~
@ -81,15 +92,11 @@ Use ``image_from_text.py`` to generate images from the command line.
.. code:: bash
$ python image_from_text.py --text='artificial intelligence' --no-mega --seed=7
$ python image_from_text.py --text='artificial intelligence' --no-mega
.. code:: bash
$ python image_from_text.py --text='trail cam footage of gollum eating watermelon' --mega --seed=1 --grid-size=3
.. |Open In Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb
.. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge
:target: https://replicate.com/kuprel/min-dalle
.. |Join us on Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.gg/xBPBXfcFHd
.. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.com/channels/823813159592001537/912729332311556136

4
cog.yaml vendored
View File

@ -6,8 +6,8 @@ build:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "min-dalle==0.2.29"
- "min-dalle==0.2.35"
run:
- pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
predict: "replicate_predictor.py:ReplicatePredictor"

28
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

View File

@ -18,13 +18,15 @@ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
class MinDalle:
def __init__(
self,
is_mega: bool,
is_reusable: bool = True,
models_root: str = 'pretrained',
dtype: torch.dtype = torch.float32,
is_mega: bool = True,
is_reusable: bool = True,
is_verbose = True
):
self.is_mega = is_mega
self.is_reusable = is_reusable
self.dtype = dtype
self.is_verbose = is_verbose
self.text_token_count = 64
self.layer_count = 24 if is_mega else 12
@ -34,7 +36,6 @@ class MinDalle:
self.text_vocab_count = 50272 if is_mega else 50264
self.image_vocab_count = 16415 if is_mega else 16384
if self.is_verbose: print("initializing MinDalle")
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
dalle_path = os.path.join(models_root, model_name)
vqgan_path = os.path.join(models_root, 'vqgan')
@ -105,7 +106,7 @@ class MinDalle:
text_token_count = self.text_token_count,
text_vocab_count = self.text_vocab_count,
layer_count = self.layer_count
)
).to(self.dtype).eval()
params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False)
del params
@ -123,7 +124,7 @@ class MinDalle:
glu_embed_count = self.glu_embed_count,
layer_count = self.layer_count,
start_token = self.image_vocab_count
)
).to(self.dtype).eval()
params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False)
del params
@ -134,7 +135,7 @@ class MinDalle:
is_downloaded = os.path.exists(self.detoker_params_path)
if not is_downloaded: self.download_detokenizer()
if self.is_verbose: print("initializing VQGanDetokenizer")
self.detokenizer = VQGanDetokenizer()
self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
params = torch.load(self.detoker_params_path)
self.detokenizer.load_state_dict(params)
del params
@ -184,38 +185,41 @@ class MinDalle:
if not self.is_reusable: self.init_encoder()
if is_verbose: print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
with torch.cuda.amp.autocast(dtype=self.dtype):
encoder_state = self.encoder.forward(text_tokens)
if not self.is_reusable: del self.encoder
if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_decoder()
encoder_state, attention_mask, attention_state, image_tokens = (
self.decoder.decode_initial(
seed,
grid_size ** 2,
text_tokens,
encoder_state
with torch.cuda.amp.autocast(dtype=self.dtype):
encoder_state, attention_mask, attention_state, image_tokens = (
self.decoder.decode_initial(
seed,
grid_size ** 2,
text_tokens,
encoder_state
)
)
)
row_count = 16
for row_index in range(row_count):
if is_verbose:
print('sampling row {} of {}'.format(row_index + 1, row_count))
attention_state, image_tokens = self.decoder.decode_row(
row_index,
log2_k,
log2_supercondition_factor,
encoder_state,
attention_mask,
attention_state,
image_tokens
)
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
tokens = image_tokens[:, 1:]
image = self.image_from_tokens(grid_size, tokens, is_verbose)
yield image
with torch.cuda.amp.autocast(dtype=self.dtype):
attention_state, image_tokens = self.decoder.decode_row(
row_index,
log2_k,
log2_supercondition_factor,
encoder_state,
attention_mask,
attention_state,
image_tokens
)
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
tokens = image_tokens[:, 1:]
image = self.image_from_tokens(grid_size, tokens, is_verbose)
yield image
def generate_image(

View File

@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
queries = self.q_proj.forward(decoder_state)
attn_mask = self.token_indices < token_index + 1
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
attention_state[:, token_index] = torch.cat([keys, values])
attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
attention_state[:, token_index] = attn_state_new
batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count]
values = attention_state[batch_count:]

View File

@ -82,7 +82,7 @@ class Upsample(Module):
self.conv = Conv2d(n, n, 3, padding=1)
def forward(self, x: Tensor) -> Tensor:
x = self.upsample.forward(x)
x = self.upsample.forward(x.to(torch.float32))
x = self.conv.forward(x)
return x

View File

@ -19,9 +19,9 @@ class ReplicatePredictor(BasePredictor):
default=True
),
grid_size: int = Input(
description='Size of the image grid',
description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 45 seconds',
ge=1,
le=4,
le=8,
default=4
),
log2_supercondition_factor: int = Input(

View File

@ -4,8 +4,8 @@ from pathlib import Path
setuptools.setup(
name='min-dalle',
description = 'min(DALL·E)',
long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.29',
# long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.35',
author='Brett Kuprel',
author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle',
@ -15,8 +15,8 @@ setuptools.setup(
],
license='MIT',
install_requires=[
'torch>=1.10.0',
'typing_extensions>=4.1.0',
'torch>=1.11',
'typing_extensions>=4.1',
'numpy>=1.21',
'pillow>=7.1',
'requests>=2.23'