forgot missing 2**
This commit is contained in:
		
							
								
								
									
										2
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -6,7 +6,7 @@ build: | |||||||
|     - "libgl1-mesa-glx" |     - "libgl1-mesa-glx" | ||||||
|     - "libglib2.0-0" |     - "libglib2.0-0" | ||||||
|   python_packages: |   python_packages: | ||||||
|     - "min-dalle==0.2.26" |     - "min-dalle==0.2.27" | ||||||
|   run: |   run: | ||||||
|     - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html |     - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html | ||||||
|  |  | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ class Predictor(BasePredictor): | |||||||
|         supercondition_factor: int = Input( |         supercondition_factor: int = Input( | ||||||
|             description='Lower results in a wider variety of images but less agreement with the text', |             description='Lower results in a wider variety of images but less agreement with the text', | ||||||
|             choices=[2, 4, 8, 16, 32, 64], |             choices=[2, 4, 8, 16, 32, 64], | ||||||
|             default=16 |             default=8 | ||||||
|         ), |         ), | ||||||
|     ) -> Iterator[Path]: |     ) -> Iterator[Path]: | ||||||
|         image_stream = self.model.generate_image_stream( |         image_stream = self.model.generate_image_stream( | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								min_dalle.ipynb
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								min_dalle.ipynb
									
									
									
									
										vendored
									
									
								
							| @@ -180,7 +180,7 @@ | |||||||
|         "grid_size = 3 #@param {type:\"integer\"}\n", |         "grid_size = 3 #@param {type:\"integer\"}\n", | ||||||
|         "seed = -1 #@param {type:\"integer\"}\n", |         "seed = -1 #@param {type:\"integer\"}\n", | ||||||
|         "intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n", |         "intermediate_image_count = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\"] {type:\"raw\"}\n", | ||||||
|         "supercondition_factor = 16 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n", |         "supercondition_factor = 8 #@param [\"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n", | ||||||
|         "\n", |         "\n", | ||||||
|         "image_stream = model.generate_image_stream(\n", |         "image_stream = model.generate_image_stream(\n", | ||||||
|         "    text=text,\n", |         "    text=text,\n", | ||||||
|   | |||||||
| @@ -164,7 +164,7 @@ class DalleBartDecoder(nn.Module): | |||||||
|             ) |             ) | ||||||
|         decoder_state = self.final_ln(decoder_state) |         decoder_state = self.final_ln(decoder_state) | ||||||
|         logits = self.lm_head(decoder_state) |         logits = self.lm_head(decoder_state) | ||||||
|         a = log2_supercondition_factor |         a = 2 ** log2_supercondition_factor | ||||||
|         logits: FloatTensor = ( |         logits: FloatTensor = ( | ||||||
|             logits[:image_count, -1] * (1 - a) +  |             logits[:image_count, -1] * (1 - a) +  | ||||||
|             logits[image_count:, -1] * a |             logits[image_count:, -1] * a | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -5,7 +5,7 @@ setuptools.setup( | |||||||
|     name='min-dalle', |     name='min-dalle', | ||||||
|     description = 'min(DALL·E)', |     description = 'min(DALL·E)', | ||||||
|     long_description=(Path(__file__).parent / "README.rst").read_text(), |     long_description=(Path(__file__).parent / "README.rst").read_text(), | ||||||
|     version='0.2.26', |     version='0.2.27', | ||||||
|     author='Brett Kuprel', |     author='Brett Kuprel', | ||||||
|     author_email='brkuprel@gmail.com', |     author_email='brkuprel@gmail.com', | ||||||
|     url='https://github.com/kuprel/min-dalle', |     url='https://github.com/kuprel/min-dalle', | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user