parent
1ef9b0b929
commit
eb9f4c6b3b
3 changed files with 69 additions and 1 deletions
@ -0,0 +1,14 @@ |
||||
build: |
||||
cuda: "11.0" |
||||
gpu: true |
||||
python_version: "3.8" |
||||
system_packages: |
||||
- "libgl1-mesa-glx" |
||||
- "libglib2.0-0" |
||||
python_packages: |
||||
- "ipython==7.21.0" |
||||
- "torch==1.10.1" |
||||
- "flax==0.4.2" |
||||
- "wandb==0.12.16" |
||||
|
||||
predict: "predict.py:Predictor" |
@ -0,0 +1,53 @@ |
||||
import tempfile |
||||
from PIL import Image |
||||
from cog import BasePredictor, Path, Input |
||||
|
||||
from min_dalle.generate_image import load_dalle_bart_metadata, tokenize_text |
||||
from min_dalle.load_params import load_dalle_bart_flax_params |
||||
from min_dalle.min_dalle_torch import generate_image_tokens_torch, detokenize_torch |
||||
|
||||
|
||||
class Predictor(BasePredictor): |
||||
def setup(self): |
||||
self.model_path = { |
||||
"mini": "pretrained/dalle_bart_mini", |
||||
"mega": "pretrained/dalle_bart_mega", |
||||
} |
||||
self.configs = { |
||||
k: load_dalle_bart_metadata(self.model_path[k]) |
||||
for k in self.model_path.keys() |
||||
} |
||||
|
||||
def predict( |
||||
self, |
||||
text: str = Input( |
||||
description="Text for generating images.", |
||||
), |
||||
model: str = Input( |
||||
choices=["mini", "mega"], |
||||
description="Choose mini or mega model.", |
||||
), |
||||
seed: int = Input( |
||||
description="Specify the seed.", |
||||
), |
||||
) -> Path: |
||||
|
||||
config, vocab, merges = self.configs[model] |
||||
text_tokens = tokenize_text(text, config, vocab, merges) |
||||
params_dalle_bart = load_dalle_bart_flax_params(self.model_path[model]) |
||||
|
||||
image_token_count = config["image_length"] |
||||
image_tokens = generate_image_tokens_torch( |
||||
text_tokens=text_tokens, |
||||
seed=seed, |
||||
config=config, |
||||
params=params_dalle_bart, |
||||
image_token_count=image_token_count, |
||||
) |
||||
|
||||
image = detokenize_torch(image_tokens, is_torch=True) |
||||
image = Image.fromarray(image) |
||||
out_path = Path(tempfile.mkdtemp()) / "output.png" |
||||
image.save(str(out_path)) |
||||
|
||||
return out_path |
Loading…
Reference in new issue