From 09a0f85b8ee3525c075d64020e1f7595f3a5c4e3 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Fri, 1 Jul 2022 11:08:33 -0400 Subject: [PATCH] separate setup processes for flax and torch --- README.md | 3 ++- min_dalle/min_dalle_torch.py | 7 ------- requirements.txt | 3 --- setup.sh => setup_flax.sh | 0 setup_torch.sh | 14 ++++++++++++++ 5 files changed, 16 insertions(+), 11 deletions(-) delete mode 100644 requirements.txt rename setup.sh => setup_flax.sh (100%) create mode 100644 setup_torch.sh diff --git a/README.md b/README.md index 266a49a..8d6a3cb 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,8 @@ It currently takes **7.4 seconds** to generate an image with DALL·E Mega with P ### Setup -Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually here: +Run either `sh setup_torch.sh` or `sh setup_flax.sh` to install dependencies and download pretrained models. The torch models can be manually downloaded [here](https://huggingface.co/kuprel/min-dalle). +The flax models can be manually downloaded here: [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 2da39af..55d45c1 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -6,7 +6,6 @@ import torch torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) -from .load_params import convert_and_save_torch_params from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch @@ -29,12 +28,6 @@ class MinDalleTorch(MinDalleBase): self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') - is_converted = os.path.exists(self.encoder_params_path) - is_converted &= os.path.exists(self.decoder_params_path) - is_converted &= os.path.exists(self.detoker_params_path) - if not is_converted: - convert_and_save_torch_params(is_mega, self.model_path) - if is_reusable: self.init_encoder() self.init_decoder() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 6dd6b09..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -torch -flax -wandb diff --git a/setup.sh b/setup_flax.sh similarity index 100% rename from setup.sh rename to setup_flax.sh diff --git a/setup_torch.sh b/setup_torch.sh new file mode 100644 index 0000000..568d845 --- /dev/null +++ b/setup_torch.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -e + +pip install torch + +mkdir -p ./pretrained/dalle_bart_mega/ +mkdir -p ./pretrained/vqgan/ +curl https://huggingface.co/kuprel/min-dalle/resolve/main/config.json -L --output ./pretrained/dalle_bart_mega/config.json +curl https://huggingface.co/kuprel/min-dalle/resolve/main/vocab.json -L --output ./pretrained/dalle_bart_mega/vocab.json +curl https://huggingface.co/kuprel/min-dalle/resolve/main/merges.txt -L --output ./pretrained/dalle_bart_mega/merges.txt +curl https://huggingface.co/kuprel/min-dalle/resolve/main/encoder.pt -L --output ./pretrained/dalle_bart_mega/encoder.pt +curl https://huggingface.co/kuprel/min-dalle/resolve/main/decoder.pt -L --output ./pretrained/dalle_bart_mega/decoder.pt +curl https://huggingface.co/kuprel/min-dalle/resolve/main/detoker.pt -L --output ./pretrained/vqgan/detoker.pt \ No newline at end of file