Merge pull request #38 from chenxwh/replicate

Add Web Demo & Docker environment
This commit is contained in:
Brett Kuprel 2022-06-29 15:20:20 -04:00 committed by GitHub
commit 15b0c03485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 1 deletions

3
README.md vendored
View File

@ -1,6 +1,7 @@
# min(DALL·E)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).

14
cog.yaml vendored Normal file
View File

@ -0,0 +1,14 @@
# Cog build configuration for the min(DALL·E) Replicate demo.
# Defines the container image (CUDA, Python, packages) and the predictor
# entry point used by `cog predict`.
build:
  cuda: "11.0"
  gpu: true
  python_version: "3.8"
  system_packages:
    # OpenGL / glib runtime libs commonly needed by image libraries.
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "ipython==7.21.0"
    - "torch==1.10.1"
    # flax is needed to load the original DALL·E Mini checkpoint params.
    - "flax==0.4.2"
    # wandb is used to download the pretrained models.
    - "wandb==0.12.16"
# Entry point: class `Predictor` in predict.py.
predict: "predict.py:Predictor"

53
predict.py Normal file
View File

@ -0,0 +1,53 @@
import tempfile
from PIL import Image
from cog import BasePredictor, Path, Input
from min_dalle.generate_image import load_dalle_bart_metadata, tokenize_text
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.min_dalle_torch import generate_image_tokens_torch, detokenize_torch
class Predictor(BasePredictor):
    """Cog predictor that wraps min(DALL·E) text-to-image generation.

    Exposes two model sizes ("mini" and "mega"); metadata for both is
    loaded at container start, while the heavy flax parameters are loaded
    lazily and cached per model.
    """

    def setup(self):
        """Load model metadata once, at container start."""
        self.model_path = {
            "mini": "pretrained/dalle_bart_mini",
            "mega": "pretrained/dalle_bart_mega",
        }
        # (config, vocab, merges) tuples keyed by model size.
        self.configs = {
            k: load_dalle_bart_metadata(self.model_path[k])
            for k in self.model_path
        }
        # Lazily-filled cache of flax parameters keyed by model size, so an
        # unused model never costs load time or memory.
        self.params_cache = {}

    def predict(
        self,
        text: str = Input(
            description="Text for generating images.",
        ),
        model: str = Input(
            choices=["mini", "mega"],
            description="Choose mini or mega model.",
        ),
        seed: int = Input(
            description="Specify the seed.",
        ),
    ) -> Path:
        """Generate one image for `text` with the chosen model and seed.

        Returns the path of the saved PNG in a fresh temp directory.
        """
        config, vocab, merges = self.configs[model]
        text_tokens = tokenize_text(text, config, vocab, merges)
        # Fix: previously the flax params were re-loaded from disk on every
        # request; cache them per model so repeat predictions are fast.
        if model not in self.params_cache:
            self.params_cache[model] = load_dalle_bart_flax_params(
                self.model_path[model]
            )
        params_dalle_bart = self.params_cache[model]
        image_token_count = config["image_length"]
        image_tokens = generate_image_tokens_torch(
            text_tokens=text_tokens,
            seed=seed,
            config=config,
            params=params_dalle_bart,
            image_token_count=image_token_count,
        )
        # Decode image tokens to a numpy array, then save via PIL.
        image = detokenize_torch(image_tokens, is_torch=True)
        image = Image.fromarray(image)
        out_path = Path(tempfile.mkdtemp()) / "output.png"
        image.save(str(out_path))
        return out_path