From 89a125b4b9462239ba75a2c1f81d8e9fa84b2225 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 5 Jul 2022 17:23:05 -0400 Subject: [PATCH] control top_k value --- cog.yaml | 8 ++++---- min_dalle/min_dalle.py | 4 ++++ min_dalle/models/dalle_bart_decoder.py | 5 ++++- setup.py | 2 +- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/cog.yaml b/cog.yaml index a37c069..371a187 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,13 +1,13 @@ build: - cuda: "11.5.1" + cuda: "11.0" gpu: true - python_version: "3.10" + python_version: "3.8" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.27" + - "min-dalle==0.2.28" run: - - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html + - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html predict: "replicate_predictor.py:ReplicatePredictor" diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 3f23878..3320b74 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -165,6 +165,7 @@ class MinDalle: seed: int, grid_size: int, log2_mid_count: int, + log2_k: int = 6, log2_supercondition_factor: int = 3, is_verbose: bool = False ) -> Iterator[Image.Image]: @@ -202,6 +203,7 @@ class MinDalle: print('sampling row {} of {}'.format(row_index + 1, row_count)) attention_state, image_tokens = self.decoder.decode_row( row_index, + log2_k, log2_supercondition_factor, encoder_state, attention_mask, @@ -219,6 +221,7 @@ class MinDalle: text: str, seed: int = -1, grid_size: int = 1, + log2_k: int = 6, log2_supercondition_factor: int = 3, is_verbose: bool = False ) -> Image.Image: @@ -228,6 +231,7 @@ class MinDalle: seed, grid_size, log2_mid_count, + log2_k, log2_supercondition_factor, is_verbose ) diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index daaeaba..393618a 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -140,6 +140,7 @@ class DalleBartDecoder(nn.Module): def decode_step( self, + log2_k: int, log2_supercondition_factor: int, attention_mask: BoolTensor, encoder_state: FloatTensor, @@ -170,7 +171,7 @@ class DalleBartDecoder(nn.Module): logits[image_count:, -1] * a ) - top_logits, _ = logits.topk(50, dim=-1) + top_logits, _ = logits.topk(2 ** log2_k, dim=-1) probs = torch.where( logits < top_logits[:, [-1]], self.zero_prob, @@ -182,6 +183,7 @@ class DalleBartDecoder(nn.Module): def decode_row( self, row_index: int, + log2_k: int, log2_supercondition_factor: int, encoder_state: FloatTensor, attention_mask: BoolTensor, @@ -191,6 +193,7 @@ class DalleBartDecoder(nn.Module): for col_index in range(16): i = 16 * row_index + col_index probs, attention_state = self.decode_step( + log2_k = log2_k, log2_supercondition_factor = log2_supercondition_factor, attention_mask = attention_mask, encoder_state = encoder_state, diff --git a/setup.py b/setup.py index 5024e6a..4fe58cc 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALLĀ·E)', long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.27', + version='0.2.28', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',