vqgan needs to be float32
This commit is contained in:
parent
da62298f06
commit
c199507a7a
2
README.md
vendored
2
README.md
vendored
|
@ -36,7 +36,7 @@ model = MinDalle(
|
|||
)
|
||||
```
|
||||
|
||||
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
|
||||
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
|
||||
|
||||
```python
|
||||
image = model.generate_image(
|
||||
|
|
6
cog.yaml
vendored
6
cog.yaml
vendored
|
@ -1,12 +1,12 @@
|
|||
build:
|
||||
cuda: "11.3"
|
||||
cuda: "11.5.1"
|
||||
gpu: true
|
||||
python_version: "3.8"
|
||||
python_version: "3.10"
|
||||
system_packages:
|
||||
- "libgl1-mesa-glx"
|
||||
- "libglib2.0-0"
|
||||
python_packages:
|
||||
- "min-dalle==0.2.35"
|
||||
- "min-dalle==0.2.36"
|
||||
run:
|
||||
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ class MinDalle:
|
|||
is_downloaded = os.path.exists(self.detoker_params_path)
|
||||
if not is_downloaded: self.download_detokenizer()
|
||||
if self.is_verbose: print("initializing VQGanDetokenizer")
|
||||
self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
|
||||
self.detokenizer = VQGanDetokenizer().eval()
|
||||
params = torch.load(self.detoker_params_path)
|
||||
self.detokenizer.load_state_dict(params)
|
||||
del params
|
||||
|
@ -216,6 +216,7 @@ class MinDalle:
|
|||
attention_state,
|
||||
image_tokens
|
||||
)
|
||||
with torch.cuda.amp.autocast(dtype=torch.float32):
|
||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||
tokens = image_tokens[:, 1:]
|
||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||
|
|
2
setup.py
2
setup.py
|
@ -5,7 +5,7 @@ setuptools.setup(
|
|||
name='min-dalle',
|
||||
description = 'min(DALL·E)',
|
||||
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
||||
version='0.2.35',
|
||||
version='0.2.36',
|
||||
author='Brett Kuprel',
|
||||
author_email='brkuprel@gmail.com',
|
||||
url='https://github.com/kuprel/min-dalle',
|
||||
|
|
Loading…
Reference in New Issue
Block a user