From c199507a7a2a827dffd651b82c8da60260ee1c44 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Thu, 7 Jul 2022 08:53:27 -0400 Subject: [PATCH] vqgan needs to be float32 --- README.md | 2 +- cog.yaml | 6 +++--- min_dalle/min_dalle.py | 3 ++- setup.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index a166ca1..2223661 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ model = MinDalle( ) ``` -The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens. +The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens. 
```python image = model.generate_image( diff --git a/cog.yaml b/cog.yaml index 7a59bb0..c0551d5 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,12 +1,12 @@ build: - cuda: "11.3" + cuda: "11.5.1" gpu: true - python_version: "3.8" + python_version: "3.10" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.35" + - "min-dalle==0.2.36" run: - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 772ed94..152146b 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -135,7 +135,7 @@ class MinDalle: is_downloaded = os.path.exists(self.detoker_params_path) if not is_downloaded: self.download_detokenizer() if self.is_verbose: print("initializing VQGanDetokenizer") - self.detokenizer = VQGanDetokenizer().to(self.dtype).eval() + self.detokenizer = VQGanDetokenizer().eval() params = torch.load(self.detoker_params_path) self.detokenizer.load_state_dict(params) del params @@ -216,6 +216,7 @@ class MinDalle: attention_state, image_tokens ) + with torch.cuda.amp.autocast(dtype=torch.float32): if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: tokens = image_tokens[:, 1:] image = self.image_from_tokens(grid_size, tokens, is_verbose) diff --git a/setup.py b/setup.py index 790236f..ae817f2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.35', + version='0.2.36', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',