separate setup processes for flax and torch
This commit is contained in:
parent
7bf76deafb
commit
09a0f85b8e
3
README.md
vendored
3
README.md
vendored
|
@ -9,7 +9,8 @@ It currently takes **7.4 seconds** to generate an image with DALL·E Mega with P
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here:
|
Run either `sh setup_torch.sh` or `sh setup_flax.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle).
|
||||||
|
The flax models can be manually downloaded here:
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .load_params import convert_and_save_torch_params
|
|
||||||
from .min_dalle_base import MinDalleBase
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
@ -29,12 +28,6 @@ class MinDalleTorch(MinDalleBase):
|
||||||
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||||
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
|
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
|
||||||
|
|
||||||
is_converted = os.path.exists(self.encoder_params_path)
|
|
||||||
is_converted &= os.path.exists(self.decoder_params_path)
|
|
||||||
is_converted &= os.path.exists(self.detoker_params_path)
|
|
||||||
if not is_converted:
|
|
||||||
convert_and_save_torch_params(is_mega, self.model_path)
|
|
||||||
|
|
||||||
if is_reusable:
|
if is_reusable:
|
||||||
self.init_encoder()
|
self.init_encoder()
|
||||||
self.init_decoder()
|
self.init_decoder()
|
||||||
|
|
3
requirements.txt
vendored
3
requirements.txt
vendored
|
@ -1,3 +0,0 @@
|
||||||
torch
|
|
||||||
flax
|
|
||||||
wandb
|
|
0
setup.sh → setup_flax.sh
vendored
0
setup.sh → setup_flax.sh
vendored
14
setup_torch.sh
vendored
Normal file
14
setup_torch.sh
vendored
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
pip install torch
|
||||||
|
|
||||||
|
mkdir -p ./pretrained/dalle_bart_mega/
|
||||||
|
mkdir -p ./pretrained/vqgan/
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/config.json -L --output ./pretrained/dalle_bart_mega/config.json
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt
|
||||||
|
curl https://huggingface.co/kuprel/min-dalle/resolve/main/detoker.pt -L --output ./pretrained/vqgan/detoker.pt
|
Loading…
Reference in New Issue
Block a user