diff --git a/README.md b/README.md index 6b9ff6b..e3f1585 100644 --- a/README.md +++ b/README.md @@ -42,13 +42,14 @@ model = MinDalle( ) ``` -The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. +The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. If `is_seamless` is true, the images grid will be tiled in token space not pixel space. ```python image = model.generate_image( text='Nuclear explosion broccoli', seed=-1, grid_size=4, + is_seamless=False, temperature=1, top_k=256, supercondition_factor=32, @@ -69,7 +70,8 @@ The images can also be generated as a `FloatTensor` in case you want to process images = model.generate_images( text='Nuclear explosion broccoli', seed=-1, - image_count=7, + grid_size=3, + is_seamless=False, temperature=1, top_k=256, supercondition_factor=16, @@ -96,7 +98,8 @@ image_stream = model.generate_image_stream( text='Dali painting of WALL·E', seed=-1, grid_size=3, - log2_mid_count=3, + progressive_outputs=True, + is_seamless=False, temperature=1, top_k=256, supercondition_factor=16, diff --git a/replicate_predictor.py b/replicate_predictor.py index 24f4a7f..c68ddd8 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -1,3 +1,4 @@ + from min_dalle import MinDalle import tempfile import torch, torch.backends.cudnn @@ -22,13 +23,10 @@ class ReplicatePredictor(BasePredictor): progressive_outputs: bool = Input(default=True), seamless: bool = Input(default=False), grid_size: int = Input(ge=1, le=9, default=5), - temperature: str = Input( - choices=( - ['1/{}'.format(2 ** i) for i in range(4, 0, -1)] + - [str(2 ** i) for i in range(5)] - ), - default='4', - description='Advanced Setting, see Readme below if interested.' + temperature: float = Input( + ge=0.01, + le=16, + default=4 ), top_k: int = Input( choices=[2 ** i for i in range(15)], @@ -46,8 +44,8 @@ class ReplicatePredictor(BasePredictor): seed = -1, grid_size = grid_size, progressive_outputs = progressive_outputs, - is_seamless=seamless, - temperature = eval(temperature), + is_seamless = seamless, + temperature = temperature, supercondition_factor = float(supercondition_factor), top_k = top_k, is_verbose = True