|
|
|
@ -2,12 +2,13 @@ import argparse |
|
|
|
|
import os |
|
|
|
|
from PIL import Image |
|
|
|
|
from min_dalle import MinDalle |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
|
parser.add_argument('--mega', action='store_true') |
|
|
|
|
parser.add_argument('--no-mega', dest='mega', action='store_false') |
|
|
|
|
parser.set_defaults(mega=False) |
|
|
|
|
parser.add_argument('--fp16', action='store_true') |
|
|
|
|
parser.add_argument('--text', type=str, default='Dali painting of WALL·E') |
|
|
|
|
parser.add_argument('--seed', type=int, default=-1) |
|
|
|
|
parser.add_argument('--grid-size', type=int, default=1) |
|
|
|
@ -41,13 +42,15 @@ def generate_image( |
|
|
|
|
grid_size: int, |
|
|
|
|
top_k: int, |
|
|
|
|
image_path: str, |
|
|
|
|
models_root: str |
|
|
|
|
models_root: str, |
|
|
|
|
fp16: bool, |
|
|
|
|
): |
|
|
|
|
model = MinDalle( |
|
|
|
|
is_mega=is_mega, |
|
|
|
|
models_root=models_root, |
|
|
|
|
is_reusable=False, |
|
|
|
|
is_verbose=True |
|
|
|
|
is_verbose=True, |
|
|
|
|
dtype=torch.float16 if fp16 else torch.float32 |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
image = model.generate_image( |
|
|
|
@ -71,5 +74,6 @@ if __name__ == '__main__': |
|
|
|
|
grid_size=args.grid_size, |
|
|
|
|
top_k=args.top_k, |
|
|
|
|
image_path=args.image_path, |
|
|
|
|
models_root=args.models_root |
|
|
|
|
models_root=args.models_root, |
|
|
|
|
fp16=args.fp16, |
|
|
|
|
) |