diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 2607026..adaa6f7 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -13,12 +13,17 @@ class MinDalle: model_path = os.path.join('pretrained', model_name) print("reading files from {}".format(model_path)) - with open(os.path.join(model_path, 'config.json'), 'r') as f: + config_path = os.path.join(model_path, 'config.json') + vocab_path = os.path.join(model_path, 'vocab.json') + merges_path = os.path.join(model_path, 'merges.txt') + + with open(config_path, 'r', encoding='utf8') as f: self.config = json.load(f) - with open(os.path.join(model_path, 'vocab.json'), 'r') as f: + with open(vocab_path, 'r', encoding='utf8') as f: vocab = json.load(f) - with open(os.path.join(model_path, 'merges.txt'), 'r') as f: + with open(merges_path, 'r', encoding='utf8') as f: merges = f.read().split("\n")[1:-1] + self.model_params = load_dalle_bart_flax_params(model_path) self.tokenizer = TextTokenizer(vocab, merges)