From d64e957731752f2b27bf4c9d6debb3313e8f8765 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Mon, 11 Jul 2022 12:28:20 -0400 Subject: [PATCH] add temperature parameter --- cog.yaml | 2 +- image_from_text.py | 11 +++- min_dalle/min_dalle.py | 86 ++++++++++++++------------ min_dalle/models/dalle_bart_decoder.py | 28 +++++---- setup.py | 2 +- 5 files changed, 74 insertions(+), 55 deletions(-) diff --git a/cog.yaml b/cog.yaml index 430c32a..a900493 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.3.11" + - "min-dalle==0.3.12" run: - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/image_from_text.py b/image_from_text.py index 7495bc9..a826012 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -13,6 +13,7 @@ parser.add_argument('--seed', type=int, default=-1) parser.add_argument('--grid-size', type=int, default=1) parser.add_argument('--image-path', type=str, default='generated') parser.add_argument('--models-root', type=str, default='pretrained') +parser.add_argument('--top_k', type=int, default=256) def ascii_from_image(image: Image.Image, size: int = 128) -> str: @@ -38,6 +39,7 @@ def generate_image( text: str, seed: int, grid_size: int, + top_k: int, image_path: str, models_root: str ): @@ -48,7 +50,13 @@ def generate_image( is_verbose=True ) - image = model.generate_image(text, seed, grid_size, is_verbose=True) + image = model.generate_image( + text, + seed, + grid_size, + top_k=top_k, + is_verbose=True + ) save_image(image, image_path) print(ascii_from_image(image, size=128)) @@ -61,6 +69,7 @@ if __name__ == '__main__': text=args.text, seed=args.seed, grid_size=args.grid_size, + top_k=args.top_k, image_path=args.image_path, models_root=args.models_root ) \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index ca51f1a..26981e6 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -177,8 +177,9 @@ class MinDalle: seed: int, image_count: int, log2_mid_count: int, - log2_k: int = 6, - log2_supercondition_factor: int = 3, + temperature: float = 1, + top_k: int = 256, + supercondition_factor: int = 16, is_verbose: bool = False ) -> Iterator[FloatTensor]: assert(log2_mid_count in range(5)) @@ -206,10 +207,10 @@ class MinDalle: with torch.cuda.amp.autocast(dtype=self.dtype): encoder_state, attention_mask, attention_state, image_tokens = ( self.decoder.decode_initial( - seed, - image_count, - text_tokens, - encoder_state + seed=seed, + image_count=image_count, + text_tokens=text_tokens, + encoder_state=encoder_state ) ) @@ -220,12 +221,13 @@ class MinDalle: with torch.cuda.amp.autocast(dtype=self.dtype): attention_state, image_tokens = self.decoder.decode_row( row_index, - log2_k, - log2_supercondition_factor, - encoder_state, - attention_mask, - attention_state, - image_tokens + temperature=temperature, + top_k=top_k, + supercondition_factor=supercondition_factor, + encoder_state=encoder_state, + attention_mask=attention_mask, + attention_state=attention_state, + image_tokens_sequence=image_tokens ) with torch.cuda.amp.autocast(dtype=torch.float32): if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: @@ -240,18 +242,20 @@ class MinDalle: seed: int, grid_size: int, log2_mid_count: int, - log2_k: int = 6, - log2_supercondition_factor: int = 3, + temperature: float = 1, + top_k: int = 256, + supercondition_factor: int = 16, is_verbose: bool = False ) -> Iterator[Image.Image]: images_stream = 
self.generate_images_stream( - text, - seed, - grid_size ** 2, - log2_mid_count, - log2_k, - log2_supercondition_factor, - is_verbose + text=text, + seed=seed, + image_count=grid_size ** 2, + log2_mid_count=log2_mid_count, + temperature=temperature, + top_k=top_k, + supercondition_factor=supercondition_factor, + is_verbose=is_verbose ) for images in images_stream: yield self.grid_from_images(images) @@ -262,19 +266,21 @@ class MinDalle: text: str, seed: int = -1, image_count: int = 1, - log2_k: int = 6, - log2_supercondition_factor: int = 3, + temperature: float = 1, + top_k: int = 1024, + supercondition_factor: int = 16, is_verbose: bool = False ) -> FloatTensor: log2_mid_count = 0 images_stream = self.generate_images_stream( - text, - seed, - image_count, - log2_mid_count, - log2_k, - log2_supercondition_factor, - is_verbose + text=text, + seed=seed, + image_count=image_count, + temperature=temperature, + log2_mid_count=log2_mid_count, + top_k=top_k, + supercondition_factor=supercondition_factor, + is_verbose=is_verbose ) return next(images_stream) @@ -284,18 +290,20 @@ class MinDalle: text: str, seed: int = -1, grid_size: int = 1, - log2_k: int = 6, - log2_supercondition_factor: int = 3, + temperature: float = 1, + top_k: int = 1024, + supercondition_factor: int = 16, is_verbose: bool = False ) -> Image.Image: log2_mid_count = 0 image_stream = self.generate_image_stream( - text, - seed, - grid_size, - log2_mid_count, - log2_k, - log2_supercondition_factor, - is_verbose + text=text, + seed=seed, + grid_size=grid_size, + log2_mid_count=log2_mid_count, + temperature=temperature, + top_k=top_k, + supercondition_factor=supercondition_factor, + is_verbose=is_verbose ) return next(image_stream) \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index 1dcee20..31b0154 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -140,8 +140,9 @@ class DalleBartDecoder(nn.Module): def decode_step( self, - log2_k: int, - log2_supercondition_factor: int, + temperature: float, + top_k: int, + supercondition_factor: int, attention_mask: BoolTensor, encoder_state: FloatTensor, attention_state: FloatTensor, @@ -166,18 +167,17 @@ class DalleBartDecoder(nn.Module): ) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) - a = 2 ** log2_supercondition_factor + a = supercondition_factor logits: FloatTensor = ( logits[:image_count, -1] * (1 - a) + logits[image_count:, -1] * a ) - top_logits, _ = logits.topk(2 ** log2_k, dim=-1) - probs = torch.where( - logits < top_logits[:, [-1]], - self.zero_prob, - torch.exp(logits - top_logits[:, [0]]) - ) + top_logits, _ = logits.topk(top_k, dim=-1) + is_kept = logits >= top_logits[:, [-1]] + logits -= top_logits[:, [0]] + logits /= max(temperature, 1e-6) + probs = torch.where(is_kept, torch.exp(logits), self.zero_prob) probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14 return probs, attention_state @@ -185,8 +185,9 @@ class DalleBartDecoder(nn.Module): def decode_row( self, row_index: int, - log2_k: int, - log2_supercondition_factor: int, + temperature: float, + top_k: int, + supercondition_factor: int, encoder_state: FloatTensor, attention_mask: BoolTensor, attention_state: FloatTensor, @@ -195,8 +196,9 @@ class DalleBartDecoder(nn.Module): for col_index in range(16): i = 16 * row_index + col_index probs, attention_state = self.decode_step( - log2_k = log2_k, - log2_supercondition_factor = log2_supercondition_factor, + 
temperature = temperature,
+                top_k = top_k,
+                supercondition_factor = supercondition_factor,
                 attention_mask = attention_mask,
                 encoder_state = encoder_state,
                 attention_state = attention_state,
diff --git a/setup.py b/setup.py
index e248c52..f67f52d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     # long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.3.11',
+    version='0.3.12',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
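
The reworked sampling step in decode_step drops the log2-scaled arguments: supercondition_factor is now applied directly, logits outside the top_k largest per row get zero probability, and the kept logits are shifted by the row maximum and divided by temperature before exponentiation. The standalone sketch below shows that top-k-plus-temperature logic in isolation; the function name and test tensors are illustrative and not part of the patch.

```python
import torch

def top_k_temperature_probs(
    logits: torch.Tensor,      # shape (batch, vocab)
    temperature: float = 1.0,
    top_k: int = 256,
) -> torch.Tensor:
    # Tokens outside the top_k largest logits per row get zero probability.
    top_logits, _ = logits.topk(top_k, dim=-1)
    is_kept = logits >= top_logits[:, [-1]]
    # Shift by the row maximum for numerical stability, then apply temperature:
    # values below 1 sharpen the distribution, values above 1 flatten it.
    logits = (logits - top_logits[:, [0]]) / max(temperature, 1e-6)
    return torch.where(is_kept, torch.exp(logits), torch.zeros_like(logits))

logits = torch.randn(2, 2 ** 14)      # two rows over the 2 ** 14 VQGAN vocabulary
probs = top_k_temperature_probs(logits, temperature=0.5, top_k=128)
tokens = torch.multinomial(probs, 1)  # multinomial accepts unnormalized weights
```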
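With this change, the sampling controls are ordinary keyword arguments on the public generation methods. A minimal usage sketch, assuming min-dalle 0.3.12 is installed, the package exposes MinDalle at the top level, and the weights sit in the default 'pretrained' directory; the prompt and parameter values are only examples:

```python
from min_dalle import MinDalle

model = MinDalle(models_root='pretrained', is_verbose=True)

image = model.generate_image(
    text='a comfy chair that looks like an avocado',
    seed=-1,                   # -1 picks a random seed
    grid_size=2,               # 2x2 grid of images in one PIL.Image
    temperature=0.7,           # lower values sample image tokens more conservatively
    top_k=128,                 # sample only from the 128 most likely tokens per step
    supercondition_factor=16,  # higher values weight the text prompt more heavily
    is_verbose=True
)
image.save('generated.png')
```

The same keyword arguments are forwarded through generate_image_stream and generate_images_stream, and image_from_text.py gains a --top_k command-line flag.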