control top_k value
This commit is contained in:
		
							
								
								
									
										8
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -1,13 +1,13 @@ | ||||
| build: | ||||
|   cuda: "11.5.1" | ||||
|   cuda: "11.0" | ||||
|   gpu: true | ||||
|   python_version: "3.10" | ||||
|   python_version: "3.8" | ||||
|   system_packages: | ||||
|     - "libgl1-mesa-glx" | ||||
|     - "libglib2.0-0" | ||||
|   python_packages: | ||||
|     - "min-dalle==0.2.27" | ||||
|     - "min-dalle==0.2.28" | ||||
|   run: | ||||
|     - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html | ||||
|     - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html | ||||
|  | ||||
| predict: "replicate_predictor.py:ReplicatePredictor" | ||||
|   | ||||
| @@ -165,6 +165,7 @@ class MinDalle: | ||||
|         seed: int, | ||||
|         grid_size: int, | ||||
|         log2_mid_count: int, | ||||
|         log2_k: int = 6, | ||||
|         log2_supercondition_factor: int = 3, | ||||
|         is_verbose: bool = False | ||||
|     ) -> Iterator[Image.Image]: | ||||
| @@ -202,6 +203,7 @@ class MinDalle: | ||||
|                 print('sampling row {} of {}'.format(row_index + 1, row_count)) | ||||
|             attention_state, image_tokens = self.decoder.decode_row( | ||||
|                 row_index, | ||||
|                 log2_k, | ||||
|                 log2_supercondition_factor, | ||||
|                 encoder_state, | ||||
|                 attention_mask, | ||||
| @@ -219,6 +221,7 @@ class MinDalle: | ||||
|         text: str, | ||||
|         seed: int = -1, | ||||
|         grid_size: int = 1, | ||||
|         log2_k: int = 6, | ||||
|         log2_supercondition_factor: int = 3, | ||||
|         is_verbose: bool = False | ||||
|     ) -> Image.Image: | ||||
| @@ -228,6 +231,7 @@ class MinDalle: | ||||
|             seed, | ||||
|             grid_size, | ||||
|             log2_mid_count, | ||||
|             log2_k, | ||||
|             log2_supercondition_factor, | ||||
|             is_verbose | ||||
|         ) | ||||
|   | ||||
| @@ -140,6 +140,7 @@ class DalleBartDecoder(nn.Module): | ||||
|  | ||||
|     def decode_step( | ||||
|         self, | ||||
|         log2_k: int, | ||||
|         log2_supercondition_factor: int, | ||||
|         attention_mask: BoolTensor, | ||||
|         encoder_state: FloatTensor, | ||||
| @@ -170,7 +171,7 @@ class DalleBartDecoder(nn.Module): | ||||
|             logits[image_count:, -1] * a | ||||
|         ) | ||||
|  | ||||
|         top_logits, _ = logits.topk(50, dim=-1) | ||||
|         top_logits, _ = logits.topk(2 ** log2_k, dim=-1) | ||||
|         probs = torch.where( | ||||
|             logits < top_logits[:, [-1]], | ||||
|             self.zero_prob, | ||||
| @@ -182,6 +183,7 @@ class DalleBartDecoder(nn.Module): | ||||
|     def decode_row( | ||||
|         self, | ||||
|         row_index: int, | ||||
|         log2_k: int, | ||||
|         log2_supercondition_factor: int, | ||||
|         encoder_state: FloatTensor, | ||||
|         attention_mask: BoolTensor, | ||||
| @@ -191,6 +193,7 @@ class DalleBartDecoder(nn.Module): | ||||
|         for col_index in range(16): | ||||
|             i = 16 * row_index + col_index | ||||
|             probs, attention_state = self.decode_step( | ||||
|                 log2_k = log2_k, | ||||
|                 log2_supercondition_factor = log2_supercondition_factor, | ||||
|                 attention_mask = attention_mask, | ||||
|                 encoder_state = encoder_state, | ||||
|   | ||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -5,7 +5,7 @@ setuptools.setup( | ||||
|     name='min-dalle', | ||||
|     description = 'min(DALL·E)', | ||||
|     long_description=(Path(__file__).parent / "README.rst").read_text(), | ||||
|     version='0.2.27', | ||||
|     version='0.2.28', | ||||
|     author='Brett Kuprel', | ||||
|     author_email='brkuprel@gmail.com', | ||||
|     url='https://github.com/kuprel/min-dalle', | ||||
|   | ||||
		Reference in New Issue
	
	Block a user