diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index bfdfd93..1ce2998 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -129,7 +129,7 @@ def convert_and_save_torch_params(is_mega: bool, model_path: str): encoder_params[i] = encoder_params[i].to(torch.float16) detoker_params = load_vqgan_torch_params('./pretrained/vqgan') - detoker_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt') + detoker_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') torch.save(encoder_params, os.path.join(model_path, 'encoder.pt')) torch.save(decoder_params, os.path.join(model_path, 'decoder.pt')) diff --git a/min_dalle/min_dalle_base.py b/min_dalle/min_dalle_base.py index 0d196b3..c1e8e03 100644 --- a/min_dalle/min_dalle_base.py +++ b/min_dalle/min_dalle_base.py @@ -3,7 +3,6 @@ import json import numpy from .text_tokenizer import TextTokenizer -from .models.vqgan_detokenizer import VQGanDetokenizer class MinDalleBase: def __init__(self, is_mega: bool): diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 1593818..2da39af 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -27,7 +27,7 @@ class MinDalleTorch(MinDalleBase): self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') - self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt') + self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt') is_converted = os.path.exists(self.encoder_params_path) is_converted &= os.path.exists(self.decoder_params_path)