diff --git a/cog.yaml b/cog.yaml index 6df9711..63ad6a4 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.26" + - "min-dalle==0.2.27" run: - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/cogrun.py b/cogrun.py index 570686e..0697bf2 100644 --- a/cogrun.py +++ b/cogrun.py @@ -34,7 +34,7 @@ class Predictor(BasePredictor): supercondition_factor: int = Input( description='Lower results in a wider variety of images but less agreement with the text', choices=[2, 4, 8, 16, 32, 64], - default=16 + default=8 ), ) -> Iterator[Path]: image_stream = self.model.generate_image_stream( diff --git a/min_dalle.ipynb b/min_dalle.ipynb index df3270f..ad529ed 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -180,7 +180,7 @@ "grid_size = 3 #@param {type:\"integer\"}\n", "seed = -1 #@param {type:\"integer\"}\n", "intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n", - "supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n", + "supercondition_factor = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n", "\n", "image_stream = model.generate_image_stream(\n", " text=text,\n", diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index 5f30ea3..daaeaba 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -164,7 +164,7 @@ class DalleBartDecoder(nn.Module): ) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) - a = log2_supercondition_factor + a = 2 ** log2_supercondition_factor logits: FloatTensor = ( logits[:image_count, -1] * (1 - a) + logits[image_count:, -1] * a diff --git a/setup.py b/setup.py index 9b52d81..5024e6a 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALLĀ·E)', long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.26', + version='0.2.27', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',