diff --git a/README.md b/README.md index d735762..17f25b1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # min(DALL·E) -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \ +Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle) This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models). diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..269e8bf --- /dev/null +++ b/cog.yaml @@ -0,0 +1,14 @@ +build: + cuda: "11.0" + gpu: true + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "ipython==7.21.0" + - "torch==1.10.1" + - "flax==0.4.2" + - "wandb==0.12.16" + +predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..f8c1e3d --- /dev/null +++ b/predict.py @@ -0,0 +1,53 @@ +import tempfile +from PIL import Image +from cog import BasePredictor, Path, Input + +from min_dalle.generate_image import load_dalle_bart_metadata, tokenize_text +from min_dalle.load_params import load_dalle_bart_flax_params +from min_dalle.min_dalle_torch import generate_image_tokens_torch, detokenize_torch + + +class Predictor(BasePredictor): + def setup(self): + self.model_path = { + "mini": "pretrained/dalle_bart_mini", + "mega": "pretrained/dalle_bart_mega", + } + self.configs = { + k: load_dalle_bart_metadata(self.model_path[k]) + for k in self.model_path.keys() + } + + def predict( + self, + text: str = Input( + description="Text for generating images.", + ), + model: str = Input( + choices=["mini", "mega"], + description="Choose mini or mega model.", + ), + seed: int = Input( + description="Specify the seed.", + ), + ) -> Path: + + config, vocab, merges = self.configs[model] + text_tokens = tokenize_text(text, config, vocab, merges) + params_dalle_bart = load_dalle_bart_flax_params(self.model_path[model]) + + image_token_count = config["image_length"] + image_tokens = generate_image_tokens_torch( + text_tokens=text_tokens, + seed=seed, + config=config, + params=params_dalle_bart, + image_token_count=image_token_count, + ) + + image = detokenize_torch(image_tokens, is_torch=True) + image = Image.fromarray(image) + out_path = Path(tempfile.mkdtemp()) / "output.png" + image.save(str(out_path)) + + return out_path