vqgan needs to be float32

This commit is contained in:
Brett Kuprel 2022-07-07 08:53:27 -04:00
parent da62298f06
commit c199507a7a
4 changed files with 7 additions and 6 deletions

2
README.md vendored
View File

@ -36,7 +36,7 @@ model = MinDalle(
) )
``` ```
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens. The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
```python ```python
image = model.generate_image( image = model.generate_image(

6
cog.yaml vendored
View File

@ -1,12 +1,12 @@
build: build:
cuda: "11.3" cuda: "11.5.1"
gpu: true gpu: true
python_version: "3.8" python_version: "3.10"
system_packages: system_packages:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.2.35" - "min-dalle==0.2.36"
run: run:
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -135,7 +135,7 @@ class MinDalle:
is_downloaded = os.path.exists(self.detoker_params_path) is_downloaded = os.path.exists(self.detoker_params_path)
if not is_downloaded: self.download_detokenizer() if not is_downloaded: self.download_detokenizer()
if self.is_verbose: print("initializing VQGanDetokenizer") if self.is_verbose: print("initializing VQGanDetokenizer")
self.detokenizer = VQGanDetokenizer().to(self.dtype).eval() self.detokenizer = VQGanDetokenizer().eval()
params = torch.load(self.detoker_params_path) params = torch.load(self.detoker_params_path)
self.detokenizer.load_state_dict(params) self.detokenizer.load_state_dict(params)
del params del params
@ -216,6 +216,7 @@ class MinDalle:
attention_state, attention_state,
image_tokens image_tokens
) )
with torch.cuda.amp.autocast(dtype=torch.float32):
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
tokens = image_tokens[:, 1:] tokens = image_tokens[:, 1:]
image = self.image_from_tokens(grid_size, tokens, is_verbose) image = self.image_from_tokens(grid_size, tokens, is_verbose)

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.35', version='0.2.36',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',