update readme

This commit is contained in:
Brett Kuprel 2022-07-17 15:29:23 -04:00
parent f0c4fc7350
commit cdcf3d3964
2 changed files with 13 additions and 12 deletions

9
README.md vendored
View File

@ -42,13 +42,14 @@ model = MinDalle(
) )
``` ```
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. If `is_seamless` is true, the image grid will be tiled in token space, not pixel space.
```python ```python
image = model.generate_image( image = model.generate_image(
text='Nuclear explosion broccoli', text='Nuclear explosion broccoli',
seed=-1, seed=-1,
grid_size=4, grid_size=4,
is_seamless=False,
temperature=1, temperature=1,
top_k=256, top_k=256,
supercondition_factor=32, supercondition_factor=32,
@ -69,7 +70,8 @@ The images can also be generated as a `FloatTensor` in case you want to process
images = model.generate_images( images = model.generate_images(
text='Nuclear explosion broccoli', text='Nuclear explosion broccoli',
seed=-1, seed=-1,
image_count=7, grid_size=3,
is_seamless=False,
temperature=1, temperature=1,
top_k=256, top_k=256,
supercondition_factor=16, supercondition_factor=16,
@ -96,7 +98,8 @@ image_stream = model.generate_image_stream(
text='Dali painting of WALL·E', text='Dali painting of WALL·E',
seed=-1, seed=-1,
grid_size=3, grid_size=3,
log2_mid_count=3, progressive_outputs=True,
is_seamless=False,
temperature=1, temperature=1,
top_k=256, top_k=256,
supercondition_factor=16, supercondition_factor=16,

View File

@ -1,3 +1,4 @@
from min_dalle import MinDalle from min_dalle import MinDalle
import tempfile import tempfile
import torch, torch.backends.cudnn import torch, torch.backends.cudnn
@ -22,13 +23,10 @@ class ReplicatePredictor(BasePredictor):
progressive_outputs: bool = Input(default=True), progressive_outputs: bool = Input(default=True),
seamless: bool = Input(default=False), seamless: bool = Input(default=False),
grid_size: int = Input(ge=1, le=9, default=5), grid_size: int = Input(ge=1, le=9, default=5),
temperature: str = Input( temperature: float = Input(
choices=( ge=0.01,
['1/{}'.format(2 ** i) for i in range(4, 0, -1)] + le=16,
[str(2 ** i) for i in range(5)] default=4
),
default='4',
description='Advanced Setting, see Readme below if interested.'
), ),
top_k: int = Input( top_k: int = Input(
choices=[2 ** i for i in range(15)], choices=[2 ** i for i in range(15)],
@ -46,8 +44,8 @@ class ReplicatePredictor(BasePredictor):
seed = -1, seed = -1,
grid_size = grid_size, grid_size = grid_size,
progressive_outputs = progressive_outputs, progressive_outputs = progressive_outputs,
is_seamless=seamless, is_seamless = seamless,
temperature = eval(temperature), temperature = temperature,
supercondition_factor = float(supercondition_factor), supercondition_factor = float(supercondition_factor),
top_k = top_k, top_k = top_k,
is_verbose = True is_verbose = True