support bfloat16

This commit is contained in:
Brett Kuprel 2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions

8
README.md vendored
View File

@ -12,7 +12,6 @@ To generate a 4x4 grid of DALL·E Mega images it takes:
- 89 sec with a T4 in Colab - 89 sec with a T4 in Colab
- 48 sec with a P100 in Colab - 48 sec with a P100 in Colab
- 14 sec with an A100 on Replicate - 14 sec with an A100 on Replicate
- TBD with an H100 (@NVIDIA?)
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax). The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
@ -30,13 +29,14 @@ Load the model parameters once and reuse the model to generate multiple images.
from min_dalle import MinDalle from min_dalle import MinDalle
model = MinDalle( model = MinDalle(
models_root='./pretrained',
dtype=torch.float32,
is_mega=True, is_mega=True,
is_reusable=True, is_reusable=True
models_root='./pretrained'
) )
``` ```
The required models will be downloaded to `models_root` if they are not already there. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top $k$ most probable tokens. The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU, you can set `dtype=torch.bfloat16` to save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
```python ```python
image = model.generate_image( image = model.generate_image(

89
README.rst vendored
View File

@ -1,16 +1,16 @@
min(DALL·E) min(DALL·E)
=========== ===========
|Open In Colab|   |Replicate|   |Join us on Discord| |Colab|   |Replicate|   |Discord|
This is a fast, minimal implementation of Boris Dayma's `DALL·E This is a fast, minimal port of Boris Dayma's `DALL·E
Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped
down for inference and converted to PyTorch. The only third party down for inference and converted to PyTorch. The only third party
dependencies are numpy, requests, pillow and torch. dependencies are numpy, requests, pillow and torch.
To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
Replicate - TBD with an H100 (@NVIDIA?) Replicate
The flax model and code for converting it to torch can be found The flax model and code for converting it to torch can be found
`here <https://github.com/kuprel/min-dalle-flax>`__. `here <https://github.com/kuprel/min-dalle-flax>`__.
@ -32,47 +32,58 @@ images.
from min_dalle import MinDalle from min_dalle import MinDalle
model = MinDalle(is_mega=True, models_root='./pretrained') model = MinDalle(
is_mega=True,
is_reusable=True,
models_root='./pretrained'
)
The required models will be downloaded to ``models_root`` if they are The required models will be downloaded to ``models_root`` if they are
not already there. Once everything has finished initializing, call not already there. Once everything has finished initializing, call
``generate_image`` with some text and a seed as many times as you want. ``generate_image`` with some text as many times as you want. Use a
positive ``seed`` for reproducible results. Higher values for
``log2_supercondition_factor`` result in better agreement with the text
but a narrower variety of generated images. Every image token is sampled
from the top-:math:`k` most probable tokens.
.. code:: python .. code:: python
text = 'Dali painting of WALL·E' image = model.generate_image(
image = model.generate_image(text, seed=0, grid_size=4) text='Nuclear explosion broccoli',
seed=-1,
grid_size=4,
log2_k=6,
log2_supercondition_factor=5,
is_verbose=False
)
display(image) display(image)
Interactive
~~~~~~~~~~~
If the model is being used interactively (e.g. in a notebook)
``generate_image_stream`` can be used to generate a stream of images as
the model is decoding. The detokenizer adds a slight delay for each
image. Setting ``log2_mid_count`` to 3 results in a total of
``2 ** 3 = 8`` generated images. The only valid values for
``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the
colab.
.. code:: python .. code:: python
text = 'Rusty Iron Man suit found abandoned in the woods being reclaimed by nature' image_stream = model.generate_image_stream(
image = model.generate_image(text, seed=0, grid_size=3) text='Dali painting of WALL·E',
display(image) seed=-1,
grid_size=3,
log2_mid_count=3,
log2_k=6,
log2_supercondition_factor=3,
is_verbose=False
)
.. code:: python for image in image_stream:
display(image)
text = 'court sketch of godzilla on trial'
image = model.generate_image(text, seed=6, grid_size=3)
display(image)
.. code:: python
text = 'a funeral at Whole Foods'
image = model.generate_image(text, seed=10, grid_size=3)
display(image)
.. code:: python
text = 'Jesus turning water into wine on Americas Got Talent'
image = model.generate_image(text, seed=2, grid_size=3)
display(image)
.. code:: python
text = 'cctv footage of Yoda robbing a liquor store'
image = model.generate_image(text, seed=0, grid_size=3)
display(image)
Command Line Command Line
~~~~~~~~~~~~ ~~~~~~~~~~~~
@ -81,15 +92,11 @@ Use ``image_from_text.py`` to generate images from the command line.
.. code:: bash .. code:: bash
$ python image_from_text.py --text='artificial intelligence' --no-mega --seed=7 $ python image_from_text.py --text='artificial intelligence' --no-mega
.. code:: bash .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
$ python image_from_text.py --text='trail cam footage of gollum eating watermelon' --mega --seed=1 --grid-size=3
.. |Open In Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb :target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb
.. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge .. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge
:target: https://replicate.com/kuprel/min-dalle :target: https://replicate.com/kuprel/min-dalle
.. |Join us on Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white .. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.gg/xBPBXfcFHd :target: https://discord.com/channels/823813159592001537/912729332311556136

4
cog.yaml vendored
View File

@ -6,8 +6,8 @@ build:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.2.29" - "min-dalle==0.2.35"
run: run:
- pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
predict: "replicate_predictor.py:ReplicatePredictor" predict: "replicate_predictor.py:ReplicatePredictor"

26
min_dalle.ipynb vendored

File diff suppressed because one or more lines are too long

View File

@ -18,13 +18,15 @@ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
class MinDalle: class MinDalle:
def __init__( def __init__(
self, self,
is_mega: bool,
is_reusable: bool = True,
models_root: str = 'pretrained', models_root: str = 'pretrained',
dtype: torch.dtype = torch.float32,
is_mega: bool = True,
is_reusable: bool = True,
is_verbose = True is_verbose = True
): ):
self.is_mega = is_mega self.is_mega = is_mega
self.is_reusable = is_reusable self.is_reusable = is_reusable
self.dtype = dtype
self.is_verbose = is_verbose self.is_verbose = is_verbose
self.text_token_count = 64 self.text_token_count = 64
self.layer_count = 24 if is_mega else 12 self.layer_count = 24 if is_mega else 12
@ -34,7 +36,6 @@ class MinDalle:
self.text_vocab_count = 50272 if is_mega else 50264 self.text_vocab_count = 50272 if is_mega else 50264
self.image_vocab_count = 16415 if is_mega else 16384 self.image_vocab_count = 16415 if is_mega else 16384
if self.is_verbose: print("initializing MinDalle")
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
dalle_path = os.path.join(models_root, model_name) dalle_path = os.path.join(models_root, model_name)
vqgan_path = os.path.join(models_root, 'vqgan') vqgan_path = os.path.join(models_root, 'vqgan')
@ -105,7 +106,7 @@ class MinDalle:
text_token_count = self.text_token_count, text_token_count = self.text_token_count,
text_vocab_count = self.text_vocab_count, text_vocab_count = self.text_vocab_count,
layer_count = self.layer_count layer_count = self.layer_count
) ).to(self.dtype).eval()
params = torch.load(self.encoder_params_path) params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False) self.encoder.load_state_dict(params, strict=False)
del params del params
@ -123,7 +124,7 @@ class MinDalle:
glu_embed_count = self.glu_embed_count, glu_embed_count = self.glu_embed_count,
layer_count = self.layer_count, layer_count = self.layer_count,
start_token = self.image_vocab_count start_token = self.image_vocab_count
) ).to(self.dtype).eval()
params = torch.load(self.decoder_params_path) params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False) self.decoder.load_state_dict(params, strict=False)
del params del params
@ -134,7 +135,7 @@ class MinDalle:
is_downloaded = os.path.exists(self.detoker_params_path) is_downloaded = os.path.exists(self.detoker_params_path)
if not is_downloaded: self.download_detokenizer() if not is_downloaded: self.download_detokenizer()
if self.is_verbose: print("initializing VQGanDetokenizer") if self.is_verbose: print("initializing VQGanDetokenizer")
self.detokenizer = VQGanDetokenizer() self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
params = torch.load(self.detoker_params_path) params = torch.load(self.detoker_params_path)
self.detokenizer.load_state_dict(params) self.detokenizer.load_state_dict(params)
del params del params
@ -184,38 +185,41 @@ class MinDalle:
if not self.is_reusable: self.init_encoder() if not self.is_reusable: self.init_encoder()
if is_verbose: print("encoding text tokens") if is_verbose: print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens) with torch.cuda.amp.autocast(dtype=self.dtype):
encoder_state = self.encoder.forward(text_tokens)
if not self.is_reusable: del self.encoder if not self.is_reusable: del self.encoder
if torch.cuda.is_available(): torch.cuda.empty_cache() if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_decoder() if not self.is_reusable: self.init_decoder()
encoder_state, attention_mask, attention_state, image_tokens = ( with torch.cuda.amp.autocast(dtype=self.dtype):
self.decoder.decode_initial( encoder_state, attention_mask, attention_state, image_tokens = (
seed, self.decoder.decode_initial(
grid_size ** 2, seed,
text_tokens, grid_size ** 2,
encoder_state text_tokens,
encoder_state
)
) )
)
row_count = 16 row_count = 16
for row_index in range(row_count): for row_index in range(row_count):
if is_verbose: if is_verbose:
print('sampling row {} of {}'.format(row_index + 1, row_count)) print('sampling row {} of {}'.format(row_index + 1, row_count))
attention_state, image_tokens = self.decoder.decode_row( with torch.cuda.amp.autocast(dtype=self.dtype):
row_index, attention_state, image_tokens = self.decoder.decode_row(
log2_k, row_index,
log2_supercondition_factor, log2_k,
encoder_state, log2_supercondition_factor,
attention_mask, encoder_state,
attention_state, attention_mask,
image_tokens attention_state,
) image_tokens
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: )
tokens = image_tokens[:, 1:] if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
image = self.image_from_tokens(grid_size, tokens, is_verbose) tokens = image_tokens[:, 1:]
yield image image = self.image_from_tokens(grid_size, tokens, is_verbose)
yield image
def generate_image( def generate_image(

View File

@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
queries = self.q_proj.forward(decoder_state) queries = self.q_proj.forward(decoder_state)
attn_mask = self.token_indices < token_index + 1 attn_mask = self.token_indices < token_index + 1
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]] attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
attention_state[:, token_index] = torch.cat([keys, values]) attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
attention_state[:, token_index] = attn_state_new
batch_count = decoder_state.shape[0] batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count] keys = attention_state[:batch_count]
values = attention_state[batch_count:] values = attention_state[batch_count:]

View File

@ -82,7 +82,7 @@ class Upsample(Module):
self.conv = Conv2d(n, n, 3, padding=1) self.conv = Conv2d(n, n, 3, padding=1)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
x = self.upsample.forward(x) x = self.upsample.forward(x.to(torch.float32))
x = self.conv.forward(x) x = self.conv.forward(x)
return x return x

View File

@ -19,9 +19,9 @@ class ReplicatePredictor(BasePredictor):
default=True default=True
), ),
grid_size: int = Input( grid_size: int = Input(
description='Size of the image grid', description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 45 seconds',
ge=1, ge=1,
le=4, le=8,
default=4 default=4
), ),
log2_supercondition_factor: int = Input( log2_supercondition_factor: int = Input(

View File

@ -4,8 +4,8 @@ from pathlib import Path
setuptools.setup( setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.29', version='0.2.35',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',
@ -15,8 +15,8 @@ setuptools.setup(
], ],
license='MIT', license='MIT',
install_requires=[ install_requires=[
'torch>=1.10.0', 'torch>=1.11',
'typing_extensions>=4.1.0', 'typing_extensions>=4.1',
'numpy>=1.21', 'numpy>=1.21',
'pillow>=7.1', 'pillow>=7.1',
'requests>=2.23' 'requests>=2.23'