update readme
This commit is contained in:
parent
f0c4fc7350
commit
cdcf3d3964
9
README.md
vendored
9
README.md
vendored
|
@ -42,13 +42,14 @@ model = MinDalle(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`.
|
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. If `is_seamless` is true, the image grid will be tiled in token space, not pixel space.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
image = model.generate_image(
|
image = model.generate_image(
|
||||||
text='Nuclear explosion broccoli',
|
text='Nuclear explosion broccoli',
|
||||||
seed=-1,
|
seed=-1,
|
||||||
grid_size=4,
|
grid_size=4,
|
||||||
|
is_seamless=False,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=256,
|
top_k=256,
|
||||||
supercondition_factor=32,
|
supercondition_factor=32,
|
||||||
|
@ -69,7 +70,8 @@ The images can also be generated as a `FloatTensor` in case you want to process
|
||||||
images = model.generate_images(
|
images = model.generate_images(
|
||||||
text='Nuclear explosion broccoli',
|
text='Nuclear explosion broccoli',
|
||||||
seed=-1,
|
seed=-1,
|
||||||
image_count=7,
|
grid_size=3,
|
||||||
|
is_seamless=False,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=256,
|
top_k=256,
|
||||||
supercondition_factor=16,
|
supercondition_factor=16,
|
||||||
|
@ -96,7 +98,8 @@ image_stream = model.generate_image_stream(
|
||||||
text='Dali painting of WALL·E',
|
text='Dali painting of WALL·E',
|
||||||
seed=-1,
|
seed=-1,
|
||||||
grid_size=3,
|
grid_size=3,
|
||||||
log2_mid_count=3,
|
progressive_outputs=True,
|
||||||
|
is_seamless=False,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=256,
|
top_k=256,
|
||||||
supercondition_factor=16,
|
supercondition_factor=16,
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
from min_dalle import MinDalle
|
from min_dalle import MinDalle
|
||||||
import tempfile
|
import tempfile
|
||||||
import torch, torch.backends.cudnn
|
import torch, torch.backends.cudnn
|
||||||
|
@ -22,13 +23,10 @@ class ReplicatePredictor(BasePredictor):
|
||||||
progressive_outputs: bool = Input(default=True),
|
progressive_outputs: bool = Input(default=True),
|
||||||
seamless: bool = Input(default=False),
|
seamless: bool = Input(default=False),
|
||||||
grid_size: int = Input(ge=1, le=9, default=5),
|
grid_size: int = Input(ge=1, le=9, default=5),
|
||||||
temperature: str = Input(
|
temperature: float = Input(
|
||||||
choices=(
|
ge=0.01,
|
||||||
['1/{}'.format(2 ** i) for i in range(4, 0, -1)] +
|
le=16,
|
||||||
[str(2 ** i) for i in range(5)]
|
default=4
|
||||||
),
|
|
||||||
default='4',
|
|
||||||
description='Advanced Setting, see Readme below if interested.'
|
|
||||||
),
|
),
|
||||||
top_k: int = Input(
|
top_k: int = Input(
|
||||||
choices=[2 ** i for i in range(15)],
|
choices=[2 ** i for i in range(15)],
|
||||||
|
@ -46,8 +44,8 @@ class ReplicatePredictor(BasePredictor):
|
||||||
seed = -1,
|
seed = -1,
|
||||||
grid_size = grid_size,
|
grid_size = grid_size,
|
||||||
progressive_outputs = progressive_outputs,
|
progressive_outputs = progressive_outputs,
|
||||||
is_seamless=seamless,
|
is_seamless = seamless,
|
||||||
temperature = eval(temperature),
|
temperature = temperature,
|
||||||
supercondition_factor = float(supercondition_factor),
|
supercondition_factor = float(supercondition_factor),
|
||||||
top_k = top_k,
|
top_k = top_k,
|
||||||
is_verbose = True
|
is_verbose = True
|
||||||
|
|
Loading…
Reference in New Issue
Block a user