diff --git a/cog.yaml b/cog.yaml index ad2cda4..6df9711 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.24" + - "min-dalle==0.2.26" run: - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/cogrun.py b/cogrun.py index da258f8..e437b82 100644 --- a/cogrun.py +++ b/cogrun.py @@ -1,3 +1,4 @@ +from contextlib import suppress from min_dalle import MinDalle import tempfile from typing import Iterator @@ -30,12 +31,18 @@ class Predictor(BasePredictor): choices=[1, 2, 4, 8, 16], default=8 ), + supercondition_factor: int = Input( + description='Lower results in a wider variety of images but less agreement with the text', + choices=[2, 4, 8, 16, 32, 64], + default=8 + ), ) -> Iterator[Path]: image_stream = self.model.generate_image_stream( text, seed, grid_size=grid_size, log2_mid_count=log2(intermediate_image_count), + log2_supercondition_factor=log2(supercondition_factor), is_verbose=True ) diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 46bd1b8..4f146e8 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -164,7 +164,8 @@ class MinDalle: text: str, seed: int, grid_size: int, - log2_mid_count: int = 0, + log2_mid_count: int, + log2_supercondition_factor: int = 3, is_verbose: bool = False ) -> Iterator[Image.Image]: if is_verbose: print("tokenizing text") @@ -200,6 +201,7 @@ class MinDalle: print('sampling row {} of {}'.format(row_index + 1, row_count)) attention_state, image_tokens = self.decoder.decode_row( row_index, + log2_supercondition_factor, encoder_state, attention_mask, attention_state, @@ -216,6 +218,7 @@ class MinDalle: text: str, seed: int = -1, grid_size: int = 1, + log2_supercondition_factor: int = 3, is_verbose: bool = False ) -> Image.Image: log2_mid_count = 0 @@ -224,6 +227,7 @@ class MinDalle: seed, grid_size, log2_mid_count, + log2_supercondition_factor, is_verbose ) return 
next(image_stream) \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index dcbc680..5f30ea3 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -116,7 +116,6 @@ class DalleBartDecoder(nn.Module): super().__init__() self.layer_count = layer_count self.embed_count = embed_count - self.condition_factor = 10.0 self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count) self.layers: List[DecoderLayer] = nn.ModuleList([ @@ -141,6 +140,7 @@ def decode_step( self, + log2_supercondition_factor: int, attention_mask: BoolTensor, encoder_state: FloatTensor, attention_state: FloatTensor, @@ -164,7 +164,7 @@ ) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) - a = self.condition_factor + a = 2 ** log2_supercondition_factor logits: FloatTensor = ( logits[:image_count, -1] * (1 - a) + logits[image_count:, -1] * a @@ -182,6 +182,7 @@ def decode_row( self, row_index: int, + log2_supercondition_factor: int, encoder_state: FloatTensor, attention_mask: BoolTensor, attention_state: FloatTensor, @@ -190,6 +191,7 @@ for col_index in range(16): i = 16 * row_index + col_index probs, attention_state = self.decode_step( + log2_supercondition_factor = log2_supercondition_factor, attention_mask = attention_mask, encoder_state = encoder_state, attention_state = attention_state, diff --git a/setup.py b/setup.py index e6baede..9b52d81 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.24', + version='0.2.26', author='Brett Kuprel', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle',