Merge pull request #38 from chenxwh/replicate
Add Web Demo & Docker environment
This commit is contained in:
commit
15b0c03485
3
README.md
vendored
3
README.md
vendored
|
@ -1,6 +1,7 @@
|
|||
# min(DALL·E)
|
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) \
|
||||
Try Replicate web demo here [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
||||
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
|
||||
|
||||
|
|
14
cog.yaml
vendored
Normal file
14
cog.yaml
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
build:
|
||||
cuda: "11.0"
|
||||
gpu: true
|
||||
python_version: "3.8"
|
||||
system_packages:
|
||||
- "libgl1-mesa-glx"
|
||||
- "libglib2.0-0"
|
||||
python_packages:
|
||||
- "ipython==7.21.0"
|
||||
- "torch==1.10.1"
|
||||
- "flax==0.4.2"
|
||||
- "wandb==0.12.16"
|
||||
|
||||
predict: "predict.py:Predictor"
|
53
predict.py
Normal file
53
predict.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import tempfile
|
||||
from PIL import Image
|
||||
from cog import BasePredictor, Path, Input
|
||||
|
||||
from min_dalle.generate_image import load_dalle_bart_metadata, tokenize_text
|
||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
||||
from min_dalle.min_dalle_torch import generate_image_tokens_torch, detokenize_torch
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
self.model_path = {
|
||||
"mini": "pretrained/dalle_bart_mini",
|
||||
"mega": "pretrained/dalle_bart_mega",
|
||||
}
|
||||
self.configs = {
|
||||
k: load_dalle_bart_metadata(self.model_path[k])
|
||||
for k in self.model_path.keys()
|
||||
}
|
||||
|
||||
def predict(
|
||||
self,
|
||||
text: str = Input(
|
||||
description="Text for generating images.",
|
||||
),
|
||||
model: str = Input(
|
||||
choices=["mini", "mega"],
|
||||
description="Choose mini or mega model.",
|
||||
),
|
||||
seed: int = Input(
|
||||
description="Specify the seed.",
|
||||
),
|
||||
) -> Path:
|
||||
|
||||
config, vocab, merges = self.configs[model]
|
||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
||||
params_dalle_bart = load_dalle_bart_flax_params(self.model_path[model])
|
||||
|
||||
image_token_count = config["image_length"]
|
||||
image_tokens = generate_image_tokens_torch(
|
||||
text_tokens=text_tokens,
|
||||
seed=seed,
|
||||
config=config,
|
||||
params=params_dalle_bart,
|
||||
image_token_count=image_token_count,
|
||||
)
|
||||
|
||||
image = detokenize_torch(image_tokens, is_torch=True)
|
||||
image = Image.fromarray(image)
|
||||
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
||||
image.save(str(out_path))
|
||||
|
||||
return out_path
|
Loading…
Reference in New Issue
Block a user