vqgan needs to be float32
This commit is contained in:
parent
da62298f06
commit
c199507a7a
2
README.md
vendored
2
README.md
vendored
|
@@ -36,7 +36,7 @@ model = MinDalle(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
|
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
image = model.generate_image(
|
image = model.generate_image(
|
||||||
|
|
6
cog.yaml
vendored
6
cog.yaml
vendored
|
@@ -1,12 +1,12 @@
|
||||||
build:
|
build:
|
||||||
cuda: "11.3"
|
cuda: "11.5.1"
|
||||||
gpu: true
|
gpu: true
|
||||||
python_version: "3.8"
|
python_version: "3.10"
|
||||||
system_packages:
|
system_packages:
|
||||||
- "libgl1-mesa-glx"
|
- "libgl1-mesa-glx"
|
||||||
- "libglib2.0-0"
|
- "libglib2.0-0"
|
||||||
python_packages:
|
python_packages:
|
||||||
- "min-dalle==0.2.35"
|
- "min-dalle==0.2.36"
|
||||||
run:
|
run:
|
||||||
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
|
|
|
@@ -135,7 +135,7 @@ class MinDalle:
|
||||||
is_downloaded = os.path.exists(self.detoker_params_path)
|
is_downloaded = os.path.exists(self.detoker_params_path)
|
||||||
if not is_downloaded: self.download_detokenizer()
|
if not is_downloaded: self.download_detokenizer()
|
||||||
if self.is_verbose: print("initializing VQGanDetokenizer")
|
if self.is_verbose: print("initializing VQGanDetokenizer")
|
||||||
self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
|
self.detokenizer = VQGanDetokenizer().eval()
|
||||||
params = torch.load(self.detoker_params_path)
|
params = torch.load(self.detoker_params_path)
|
||||||
self.detokenizer.load_state_dict(params)
|
self.detokenizer.load_state_dict(params)
|
||||||
del params
|
del params
|
||||||
|
@@ -216,6 +216,7 @@ class MinDalle:
|
||||||
attention_state,
|
attention_state,
|
||||||
image_tokens
|
image_tokens
|
||||||
)
|
)
|
||||||
|
with torch.cuda.amp.autocast(dtype=torch.float32):
|
||||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||||
tokens = image_tokens[:, 1:]
|
tokens = image_tokens[:, 1:]
|
||||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||||
|
|
2
setup.py
2
setup.py
|
@@ -5,7 +5,7 @@ setuptools.setup(
|
||||||
name='min-dalle',
|
name='min-dalle',
|
||||||
description = 'min(DALL·E)',
|
description = 'min(DALL·E)',
|
||||||
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
||||||
version='0.2.35',
|
version='0.2.36',
|
||||||
author='Brett Kuprel',
|
author='Brett Kuprel',
|
||||||
author_email='brkuprel@gmail.com',
|
author_email='brkuprel@gmail.com',
|
||||||
url='https://github.com/kuprel/min-dalle',
|
url='https://github.com/kuprel/min-dalle',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user