updated readme
This commit is contained in:
parent
2311a1af7b
commit
08b158d580
2
README.md
vendored
2
README.md
vendored
|
@ -3,7 +3,7 @@
|
||||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||||
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
[![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
|
||||||
|
|
||||||
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
|
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which are saved with `torch.save` the first time the model is loaded), and wandb is only used to download the models.
|
||||||
|
|
||||||
It currently takes about 10 seconds to generate an avocado armchair with DALL·E Mega in PyTorch on Colab with a reusable model and high-RAM GPU runtime.
|
It currently takes about 10 seconds to generate an avocado armchair with DALL·E Mega in PyTorch on Colab with a reusable model and high-RAM GPU runtime.
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||||
return P
|
return P
|
||||||
|
|
||||||
|
|
||||||
def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
|
def convert_and_save_torch_params(is_mega: bool, model_path: str):
|
||||||
print("converting params to torch")
|
print("converting params to torch")
|
||||||
layer_count = 24 if is_mega else 12
|
layer_count = 24 if is_mega else 12
|
||||||
flax_params = load_dalle_bart_flax_params(model_path)
|
flax_params = load_dalle_bart_flax_params(model_path)
|
||||||
|
|
|
@ -7,7 +7,7 @@ torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .load_params import (
|
from .load_params import (
|
||||||
convert_and_save_mega_torch_params,
|
convert_and_save_torch_params,
|
||||||
load_dalle_bart_flax_params
|
load_dalle_bart_flax_params
|
||||||
)
|
)
|
||||||
from .min_dalle_base import MinDalleBase
|
from .min_dalle_base import MinDalleBase
|
||||||
|
@ -36,7 +36,7 @@ class MinDalleTorch(MinDalleBase):
|
||||||
is_converted = os.path.exists(self.encoder_params_path)
|
is_converted = os.path.exists(self.encoder_params_path)
|
||||||
is_converted &= os.path.exists(self.decoder_params_path)
|
is_converted &= os.path.exists(self.decoder_params_path)
|
||||||
if not is_converted:
|
if not is_converted:
|
||||||
convert_and_save_mega_torch_params(is_mega, self.model_path)
|
convert_and_save_torch_params(is_mega, self.model_path)
|
||||||
|
|
||||||
if is_reusable:
|
if is_reusable:
|
||||||
self.init_encoder()
|
self.init_encoder()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user