Merge pull request #38 from chenxwh/replicate
Add Web Demo & Docker environment
This commit is contained in:
		
							
								
								
									
										3
									
								
								README.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								README.md
									
									
									
									
										vendored
									
									
								
							| @@ -1,6 +1,7 @@ | |||||||
| # min(DALL·E) | # min(DALL·E) | ||||||
|  |  | ||||||
| [](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) | [](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \ | ||||||
|  | Try Replicate web demo here [](https://replicate.com/kuprel/min-dalle) | ||||||
|  |  | ||||||
| This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini).  It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch.  The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).   | This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini).  It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch.  The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).   | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										14
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								cog.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | build: | ||||||
|  |   cuda: "11.0" | ||||||
|  |   gpu: true | ||||||
|  |   python_version: "3.8" | ||||||
|  |   system_packages: | ||||||
|  |     - "libgl1-mesa-glx" | ||||||
|  |     - "libglib2.0-0" | ||||||
|  |   python_packages: | ||||||
|  |     - "ipython==7.21.0" | ||||||
|  |     - "torch==1.10.1" | ||||||
|  |     - "flax==0.4.2" | ||||||
|  |     - "wandb==0.12.16" | ||||||
|  |  | ||||||
|  | predict: "predict.py:Predictor" | ||||||
							
								
								
									
										53
									
								
								predict.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								predict.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | import tempfile | ||||||
|  | from PIL import Image | ||||||
|  | from cog import BasePredictor, Path, Input | ||||||
|  |  | ||||||
|  | from min_dalle.generate_image import load_dalle_bart_metadata, tokenize_text | ||||||
|  | from min_dalle.load_params import load_dalle_bart_flax_params | ||||||
|  | from min_dalle.min_dalle_torch import generate_image_tokens_torch, detokenize_torch | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Predictor(BasePredictor): | ||||||
|  |     def setup(self): | ||||||
|  |         self.model_path = { | ||||||
|  |             "mini": "pretrained/dalle_bart_mini", | ||||||
|  |             "mega": "pretrained/dalle_bart_mega", | ||||||
|  |         } | ||||||
|  |         self.configs = { | ||||||
|  |             k: load_dalle_bart_metadata(self.model_path[k]) | ||||||
|  |             for k in self.model_path.keys() | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     def predict( | ||||||
|  |         self, | ||||||
|  |         text: str = Input( | ||||||
|  |             description="Text for generating images.", | ||||||
|  |         ), | ||||||
|  |         model: str = Input( | ||||||
|  |             choices=["mini", "mega"], | ||||||
|  |             description="Choose mini or mega model.", | ||||||
|  |         ), | ||||||
|  |         seed: int = Input( | ||||||
|  |             description="Specify the seed.", | ||||||
|  |         ), | ||||||
|  |     ) -> Path: | ||||||
|  |  | ||||||
|  |         config, vocab, merges = self.configs[model] | ||||||
|  |         text_tokens = tokenize_text(text, config, vocab, merges) | ||||||
|  |         params_dalle_bart = load_dalle_bart_flax_params(self.model_path[model]) | ||||||
|  |  | ||||||
|  |         image_token_count = config["image_length"] | ||||||
|  |         image_tokens = generate_image_tokens_torch( | ||||||
|  |             text_tokens=text_tokens, | ||||||
|  |             seed=seed, | ||||||
|  |             config=config, | ||||||
|  |             params=params_dalle_bart, | ||||||
|  |             image_token_count=image_token_count, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         image = detokenize_torch(image_tokens, is_torch=True) | ||||||
|  |         image = Image.fromarray(image) | ||||||
|  |         out_path = Path(tempfile.mkdtemp()) / "output.png" | ||||||
|  |         image.save(str(out_path)) | ||||||
|  |  | ||||||
|  |         return out_path | ||||||
		Reference in New Issue
	
	Block a user