support bfloat16
This commit is contained in:
		
							
								
								
									
										8
									
								
								README.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								README.md
									
									
									
									
										vendored
									
									
								
							| @@ -12,7 +12,6 @@ To generate a 4x4 grid of DALL·E Mega images it takes: | |||||||
| - 89 sec with a T4 in Colab | - 89 sec with a T4 in Colab | ||||||
| - 48 sec with a P100 in Colab | - 48 sec with a P100 in Colab | ||||||
| - 14 sec with an A100 on Replicate | - 14 sec with an A100 on Replicate | ||||||
| - TBD with an H100 (@NVIDIA?) |  | ||||||
|  |  | ||||||
| The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax). | The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax). | ||||||
|  |  | ||||||
| @@ -30,13 +29,14 @@ Load the model parameters once and reuse the model to generate multiple images. | |||||||
| from min_dalle import MinDalle | from min_dalle import MinDalle | ||||||
|  |  | ||||||
| model = MinDalle( | model = MinDalle( | ||||||
|  |     models_root='./pretrained', | ||||||
|  |     dtype=torch.float32, | ||||||
|     is_mega=True,  |     is_mega=True,  | ||||||
|     is_reusable=True, |     is_reusable=True | ||||||
|     models_root='./pretrained' |  | ||||||
| ) | ) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| The required models will be downloaded to `models_root` if they are not already there.  Once everything has finished initializing, call `generate_image` with some text as many times as you want.  Use a positive `seed` for reproducible results.  Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images.  Every image token is sampled from the top $k$ most probable tokens. | The required models will be downloaded to `models_root` if they are not already there.  If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory.  There is still an issue with `dtype=torch.float16` that needs to be sorted out.  Once everything has finished initializing, call `generate_image` with some text as many times as you want.  Use a positive `seed` for reproducible results.  Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images.  Every image token is sampled from the top-$k$ most probable tokens. | ||||||
|  |  | ||||||
| ```python | ```python | ||||||
| image = model.generate_image( | image = model.generate_image( | ||||||
|   | |||||||
							
								
								
									
										83
									
								
								README.rst
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										83
									
								
								README.rst
									
									
									
									
										vendored
									
									
								
							| @@ -1,16 +1,16 @@ | |||||||
| min(DALL·E) | min(DALL·E) | ||||||
| =========== | =========== | ||||||
|  |  | ||||||
| |Open In Colab|   |Replicate|   |Join us on Discord| | |Colab|   |Replicate|   |Discord| | ||||||
|  |  | ||||||
| This is a fast, minimal implementation of Boris Dayma’s `DALL·E | This is a fast, minimal port of Boris Dayma’s `DALL·E | ||||||
| Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped | Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped | ||||||
| down for inference and converted to PyTorch. The only third party | down for inference and converted to PyTorch. The only third party | ||||||
| dependencies are numpy, requests, pillow and torch. | dependencies are numpy, requests, pillow and torch. | ||||||
|  |  | ||||||
| To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a | To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a | ||||||
| T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on | T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on | ||||||
| Replicate - TBD with an H100 (@NVIDIA?) | Replicate | ||||||
|  |  | ||||||
| The flax model and code for converting it to torch can be found | The flax model and code for converting it to torch can be found | ||||||
| `here <https://github.com/kuprel/min-dalle-flax>`__. | `here <https://github.com/kuprel/min-dalle-flax>`__. | ||||||
| @@ -32,46 +32,57 @@ images. | |||||||
|  |  | ||||||
|    from min_dalle import MinDalle |    from min_dalle import MinDalle | ||||||
|  |  | ||||||
|    model = MinDalle(is_mega=True, models_root='./pretrained') |    model = MinDalle( | ||||||
|  |        is_mega=True,  | ||||||
|  |        is_reusable=True, | ||||||
|  |        models_root='./pretrained' | ||||||
|  |    ) | ||||||
|  |  | ||||||
| The required models will be downloaded to ``models_root`` if they are | The required models will be downloaded to ``models_root`` if they are | ||||||
| not already there. Once everything has finished initializing, call | not already there. Once everything has finished initializing, call | ||||||
| ``generate_image`` with some text and a seed as many times as you want. | ``generate_image`` with some text as many times as you want. Use a | ||||||
|  | positive ``seed`` for reproducible results. Higher values for | ||||||
|  | ``log2_supercondition_factor`` result in better agreement with the text | ||||||
|  | but a narrower variety of generated images. Every image token is sampled | ||||||
|  | from the top-:math:`k` most probable tokens. | ||||||
|  |  | ||||||
| .. code:: python | .. code:: python | ||||||
|  |  | ||||||
|    text = 'Dali painting of WALL·E' |    image = model.generate_image( | ||||||
|    image = model.generate_image(text, seed=0, grid_size=4) |        text='Nuclear explosion broccoli', | ||||||
|  |        seed=-1, | ||||||
|  |        grid_size=4, | ||||||
|  |        log2_k=6, | ||||||
|  |        log2_supercondition_factor=5, | ||||||
|  |        is_verbose=False | ||||||
|  |    ) | ||||||
|  |  | ||||||
|    display(image) |    display(image) | ||||||
|  |  | ||||||
| .. code:: python | Interactive | ||||||
|  | ~~~~~~~~~~~ | ||||||
|  |  | ||||||
|    text = 'Rusty Iron Man suit found abandoned in the woods being reclaimed by nature' | If the model is being used interactively (e.g. in a notebook) | ||||||
|    image = model.generate_image(text, seed=0, grid_size=3) | ``generate_image_stream`` can be used to generate a stream of images as | ||||||
|    display(image) | the model is decoding. The detokenizer adds a slight delay for each | ||||||
|  | image. Setting ``log2_mid_count`` to 3 results in a total of | ||||||
|  | ``2 ** 3 = 8`` generated images. The only valid values for | ||||||
|  | ``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the | ||||||
|  | colab. | ||||||
|  |  | ||||||
| .. code:: python | .. code:: python | ||||||
|  |  | ||||||
|    text = 'court sketch of godzilla on trial' |    image_stream = model.generate_image_stream( | ||||||
|    image = model.generate_image(text, seed=6, grid_size=3) |        text='Dali painting of WALL·E', | ||||||
|    display(image) |        seed=-1, | ||||||
|  |        grid_size=3, | ||||||
|  |        log2_mid_count=3, | ||||||
|  |        log2_k=6, | ||||||
|  |        log2_supercondition_factor=3, | ||||||
|  |        is_verbose=False | ||||||
|  |    ) | ||||||
|  |  | ||||||
| .. code:: python |    for image in image_stream: | ||||||
|  |  | ||||||
|    text = 'a funeral at Whole Foods' |  | ||||||
|    image = model.generate_image(text, seed=10, grid_size=3) |  | ||||||
|    display(image) |  | ||||||
|  |  | ||||||
| .. code:: python |  | ||||||
|  |  | ||||||
|    text = 'Jesus turning water into wine on Americas Got Talent' |  | ||||||
|    image = model.generate_image(text, seed=2, grid_size=3) |  | ||||||
|    display(image) |  | ||||||
|  |  | ||||||
| .. code:: python |  | ||||||
|  |  | ||||||
|    text = 'cctv footage of Yoda robbing a liquor store' |  | ||||||
|    image = model.generate_image(text, seed=0, grid_size=3) |  | ||||||
|        display(image) |        display(image) | ||||||
|  |  | ||||||
| Command Line | Command Line | ||||||
| @@ -81,15 +92,11 @@ Use ``image_from_text.py`` to generate images from the command line. | |||||||
|  |  | ||||||
| .. code:: bash | .. code:: bash | ||||||
|  |  | ||||||
|    $ python image_from_text.py --text='artificial intelligence' --no-mega --seed=7 |    $ python image_from_text.py --text='artificial intelligence' --no-mega | ||||||
|  |  | ||||||
| .. code:: bash | .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg | ||||||
|  |  | ||||||
|    $ python image_from_text.py --text='trail cam footage of gollum eating watermelon' --mega --seed=1 --grid-size=3 |  | ||||||
|  |  | ||||||
| .. |Open In Colab| image:: https://colab.research.google.com/assets/colab-badge.svg |  | ||||||
|    :target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb |    :target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb | ||||||
| .. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge | .. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge | ||||||
|    :target: https://replicate.com/kuprel/min-dalle |    :target: https://replicate.com/kuprel/min-dalle | ||||||
| .. |Join us on Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white | .. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white | ||||||
|    :target: https://discord.gg/xBPBXfcFHd |    :target: https://discord.com/channels/823813159592001537/912729332311556136 | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -6,8 +6,8 @@ build: | |||||||
|     - "libgl1-mesa-glx" |     - "libgl1-mesa-glx" | ||||||
|     - "libglib2.0-0" |     - "libglib2.0-0" | ||||||
|   python_packages: |   python_packages: | ||||||
|     - "min-dalle==0.2.29" |     - "min-dalle==0.2.35" | ||||||
|   run: |   run: | ||||||
|     - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html |     - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html | ||||||
|  |  | ||||||
| predict: "replicate_predictor.py:ReplicatePredictor" | predict: "replicate_predictor.py:ReplicatePredictor" | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								min_dalle.ipynb
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										26
									
								
								min_dalle.ipynb
									
									
									
									
										vendored
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -18,13 +18,15 @@ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' | |||||||
| class MinDalle: | class MinDalle: | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         is_mega: bool,  |  | ||||||
|         is_reusable: bool = True, |  | ||||||
|         models_root: str = 'pretrained', |         models_root: str = 'pretrained', | ||||||
|  |         dtype: torch.dtype = torch.float32, | ||||||
|  |         is_mega: bool = True,  | ||||||
|  |         is_reusable: bool = True, | ||||||
|         is_verbose = True |         is_verbose = True | ||||||
|     ): |     ): | ||||||
|         self.is_mega = is_mega |         self.is_mega = is_mega | ||||||
|         self.is_reusable = is_reusable |         self.is_reusable = is_reusable | ||||||
|  |         self.dtype = dtype | ||||||
|         self.is_verbose = is_verbose |         self.is_verbose = is_verbose | ||||||
|         self.text_token_count = 64 |         self.text_token_count = 64 | ||||||
|         self.layer_count = 24 if is_mega else 12 |         self.layer_count = 24 if is_mega else 12 | ||||||
| @@ -34,7 +36,6 @@ class MinDalle: | |||||||
|         self.text_vocab_count = 50272 if is_mega else 50264 |         self.text_vocab_count = 50272 if is_mega else 50264 | ||||||
|         self.image_vocab_count = 16415 if is_mega else 16384 |         self.image_vocab_count = 16415 if is_mega else 16384 | ||||||
|  |  | ||||||
|         if self.is_verbose: print("initializing MinDalle") |  | ||||||
|         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') |         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') | ||||||
|         dalle_path = os.path.join(models_root, model_name) |         dalle_path = os.path.join(models_root, model_name) | ||||||
|         vqgan_path = os.path.join(models_root, 'vqgan') |         vqgan_path = os.path.join(models_root, 'vqgan') | ||||||
| @@ -105,7 +106,7 @@ class MinDalle: | |||||||
|             text_token_count = self.text_token_count, |             text_token_count = self.text_token_count, | ||||||
|             text_vocab_count = self.text_vocab_count, |             text_vocab_count = self.text_vocab_count, | ||||||
|             layer_count = self.layer_count |             layer_count = self.layer_count | ||||||
|         ) |         ).to(self.dtype).eval() | ||||||
|         params = torch.load(self.encoder_params_path) |         params = torch.load(self.encoder_params_path) | ||||||
|         self.encoder.load_state_dict(params, strict=False) |         self.encoder.load_state_dict(params, strict=False) | ||||||
|         del params |         del params | ||||||
| @@ -123,7 +124,7 @@ class MinDalle: | |||||||
|             glu_embed_count = self.glu_embed_count, |             glu_embed_count = self.glu_embed_count, | ||||||
|             layer_count = self.layer_count, |             layer_count = self.layer_count, | ||||||
|             start_token = self.image_vocab_count |             start_token = self.image_vocab_count | ||||||
|         ) |         ).to(self.dtype).eval() | ||||||
|         params = torch.load(self.decoder_params_path) |         params = torch.load(self.decoder_params_path) | ||||||
|         self.decoder.load_state_dict(params, strict=False) |         self.decoder.load_state_dict(params, strict=False) | ||||||
|         del params |         del params | ||||||
| @@ -134,7 +135,7 @@ class MinDalle: | |||||||
|         is_downloaded = os.path.exists(self.detoker_params_path) |         is_downloaded = os.path.exists(self.detoker_params_path) | ||||||
|         if not is_downloaded: self.download_detokenizer() |         if not is_downloaded: self.download_detokenizer() | ||||||
|         if self.is_verbose: print("initializing VQGanDetokenizer") |         if self.is_verbose: print("initializing VQGanDetokenizer") | ||||||
|         self.detokenizer = VQGanDetokenizer() |         self.detokenizer = VQGanDetokenizer().to(self.dtype).eval() | ||||||
|         params = torch.load(self.detoker_params_path) |         params = torch.load(self.detoker_params_path) | ||||||
|         self.detokenizer.load_state_dict(params) |         self.detokenizer.load_state_dict(params) | ||||||
|         del params |         del params | ||||||
| @@ -184,12 +185,14 @@ class MinDalle: | |||||||
|  |  | ||||||
|         if not self.is_reusable: self.init_encoder() |         if not self.is_reusable: self.init_encoder() | ||||||
|         if is_verbose: print("encoding text tokens") |         if is_verbose: print("encoding text tokens") | ||||||
|  |         with torch.cuda.amp.autocast(dtype=self.dtype): | ||||||
|             encoder_state = self.encoder.forward(text_tokens) |             encoder_state = self.encoder.forward(text_tokens) | ||||||
|         if not self.is_reusable: del self.encoder |         if not self.is_reusable: del self.encoder | ||||||
|         if torch.cuda.is_available(): torch.cuda.empty_cache() |         if torch.cuda.is_available(): torch.cuda.empty_cache() | ||||||
|  |  | ||||||
|         if not self.is_reusable: self.init_decoder() |         if not self.is_reusable: self.init_decoder() | ||||||
|  |  | ||||||
|  |         with torch.cuda.amp.autocast(dtype=self.dtype): | ||||||
|             encoder_state, attention_mask, attention_state, image_tokens = (  |             encoder_state, attention_mask, attention_state, image_tokens = (  | ||||||
|                 self.decoder.decode_initial( |                 self.decoder.decode_initial( | ||||||
|                     seed,  |                     seed,  | ||||||
| @@ -203,6 +206,7 @@ class MinDalle: | |||||||
|         for row_index in range(row_count): |         for row_index in range(row_count): | ||||||
|             if is_verbose:  |             if is_verbose:  | ||||||
|                 print('sampling row {} of {}'.format(row_index + 1, row_count)) |                 print('sampling row {} of {}'.format(row_index + 1, row_count)) | ||||||
|  |             with torch.cuda.amp.autocast(dtype=self.dtype): | ||||||
|                 attention_state, image_tokens = self.decoder.decode_row( |                 attention_state, image_tokens = self.decoder.decode_row( | ||||||
|                     row_index, |                     row_index, | ||||||
|                     log2_k, |                     log2_k, | ||||||
|   | |||||||
| @@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase): | |||||||
|         queries = self.q_proj.forward(decoder_state) |         queries = self.q_proj.forward(decoder_state) | ||||||
|         attn_mask = self.token_indices < token_index + 1 |         attn_mask = self.token_indices < token_index + 1 | ||||||
|         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]] |         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]] | ||||||
|         attention_state[:, token_index] = torch.cat([keys, values]) |         attn_state_new = torch.cat([keys, values]).to(attention_state.dtype) | ||||||
|  |         attention_state[:, token_index] = attn_state_new | ||||||
|         batch_count = decoder_state.shape[0] |         batch_count = decoder_state.shape[0] | ||||||
|         keys = attention_state[:batch_count] |         keys = attention_state[:batch_count] | ||||||
|         values = attention_state[batch_count:] |         values = attention_state[batch_count:] | ||||||
|   | |||||||
| @@ -82,7 +82,7 @@ class Upsample(Module): | |||||||
|         self.conv = Conv2d(n, n, 3, padding=1) |         self.conv = Conv2d(n, n, 3, padding=1) | ||||||
|  |  | ||||||
|     def forward(self, x: Tensor) -> Tensor: |     def forward(self, x: Tensor) -> Tensor: | ||||||
|         x = self.upsample.forward(x) |         x = self.upsample.forward(x.to(torch.float32)) | ||||||
|         x = self.conv.forward(x) |         x = self.conv.forward(x) | ||||||
|         return x |         return x | ||||||
|  |  | ||||||
|   | |||||||
| @@ -19,9 +19,9 @@ class ReplicatePredictor(BasePredictor): | |||||||
|             default=True |             default=True | ||||||
|         ), |         ), | ||||||
|         grid_size: int = Input( |         grid_size: int = Input( | ||||||
|             description='Size of the image grid', |             description='Size of the image grid.  4x4 takes about 15 seconds, 8x8 takes about 45 seconds', | ||||||
|             ge=1, |             ge=1, | ||||||
|             le=4, |             le=8, | ||||||
|             default=4 |             default=4 | ||||||
|         ), |         ), | ||||||
|         log2_supercondition_factor: int = Input( |         log2_supercondition_factor: int = Input( | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								setup.py
									
									
									
									
									
								
							| @@ -4,8 +4,8 @@ from pathlib import Path | |||||||
| setuptools.setup( | setuptools.setup( | ||||||
|     name='min-dalle', |     name='min-dalle', | ||||||
|     description = 'min(DALL·E)', |     description = 'min(DALL·E)', | ||||||
|     long_description=(Path(__file__).parent / "README.rst").read_text(), |     # long_description=(Path(__file__).parent / "README.rst").read_text(), | ||||||
|     version='0.2.29', |     version='0.2.35', | ||||||
|     author='Brett Kuprel', |     author='Brett Kuprel', | ||||||
|     author_email='brkuprel@gmail.com', |     author_email='brkuprel@gmail.com', | ||||||
|     url='https://github.com/kuprel/min-dalle', |     url='https://github.com/kuprel/min-dalle', | ||||||
| @@ -15,8 +15,8 @@ setuptools.setup( | |||||||
|     ], |     ], | ||||||
|     license='MIT', |     license='MIT', | ||||||
|     install_requires=[ |     install_requires=[ | ||||||
|         'torch>=1.10.0', |         'torch>=1.11', | ||||||
|         'typing_extensions>=4.1.0', |         'typing_extensions>=4.1', | ||||||
|         'numpy>=1.21', |         'numpy>=1.21', | ||||||
|         'pillow>=7.1', |         'pillow>=7.1', | ||||||
|         'requests>=2.23' |         'requests>=2.23' | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user