From b913b58353afaf33ca78d405a17643516bcd742f Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Thu, 30 Jun 2022 14:54:08 -0400 Subject: [PATCH] pre converting params to torch allows mega to run in standard colab runtime --- README.md | 2 +- min_dalle/load_params.py | 28 ++++++++++++++++++++++++++-- min_dalle/min_dalle_base.py | 11 +++++------ min_dalle/min_dalle_flax.py | 3 +++ min_dalle/min_dalle_torch.py | 31 +++++++++++++++++++------------ 5 files changed, 54 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index a96edbd..46724cc 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models. -It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab (with nonexpendable model in high RAM runtime) +It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab ### Setup diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index 38deec2..4c647d5 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -1,6 +1,5 @@ import os import numpy -from copy import deepcopy from typing import Dict from flax.traverse_util import flatten_dict from flax.serialization import msgpack_restore @@ -105,4 +104,29 @@ def convert_dalle_bart_torch_from_flax_params( P['embed_tokens.weight'] = P.pop('embed_tokens.embedding') P['embed_positions.weight'] = P.pop('embed_positions.embedding') - return P \ No newline at end of file + return P + + +def convert_and_save_mega_torch_params(is_mega: bool, model_path: str): + print("converting params to torch") + layer_count = 24 if is_mega else 12 + flax_params = load_dalle_bart_flax_params(model_path) + encoder_params = convert_dalle_bart_torch_from_flax_params( + flax_params['encoder'], + layer_count=layer_count, + is_encoder=True + ) + decoder_params = convert_dalle_bart_torch_from_flax_params( + flax_params['decoder'], + layer_count=layer_count, + is_encoder=False + ) + + for i in decoder_params: + decoder_params[i] = decoder_params[i].to(torch.float16) + + for i in encoder_params: + encoder_params[i] = encoder_params[i].to(torch.float16) + + torch.save(encoder_params, os.path.join(model_path, 'encoder.pt')) + torch.save(decoder_params, os.path.join(model_path, 'decoder.pt')) \ No newline at end of file diff --git a/min_dalle/min_dalle_base.py b/min_dalle/min_dalle_base.py index aa7cd22..1bde741 100644 --- a/min_dalle/min_dalle_base.py +++ b/min_dalle/min_dalle_base.py @@ -10,12 +10,12 @@ class MinDalleBase: def __init__(self, is_mega: bool): self.is_mega = is_mega model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') - model_path = os.path.join('pretrained', model_name) + self.model_path = os.path.join('pretrained', model_name) - print("reading files from {}".format(model_path)) - config_path = os.path.join(model_path, 'config.json') - vocab_path = os.path.join(model_path, 'vocab.json') - merges_path = os.path.join(model_path, 'merges.txt') + print("reading files from {}".format(self.model_path)) + config_path = os.path.join(self.model_path, 'config.json') + vocab_path = os.path.join(self.model_path, 'vocab.json') + merges_path = os.path.join(self.model_path, 'merges.txt') with open(config_path, 'r', encoding='utf8') as f: self.config = json.load(f) @@ -24,7 +24,6 @@ class MinDalleBase: with open(merges_path, 'r', encoding='utf8') as f: merges = f.read().split("\n")[1:-1] - self.model_params = load_dalle_bart_flax_params(model_path) self.tokenizer = TextTokenizer(vocab, merges) diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py index bc5ac81..176ce6b 100644 --- a/min_dalle/min_dalle_flax.py +++ b/min_dalle/min_dalle_flax.py @@ -7,12 +7,15 @@ from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax +from .load_params import load_dalle_bart_flax_params + class MinDalleFlax(MinDalleBase): def __init__(self, is_mega: bool, is_reusable: bool = True): super().__init__(is_mega) self.is_reusable = is_reusable print("initializing MinDalleFlax") + self.model_params = load_dalle_bart_flax_params(self.model_path) if is_reusable: self.init_encoder() self.init_decoder() diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 525ef9f..829dff6 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -6,7 +6,10 @@ import torch torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) -from .load_params import convert_dalle_bart_torch_from_flax_params +from .load_params import ( + convert_and_save_mega_torch_params, + load_dalle_bart_flax_params +) from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch @@ -19,10 +22,22 @@ class MinDalleTorch(MinDalleBase): is_reusable: bool = True, token_count: int = 256 ): + print("initializing MinDalleTorch") super().__init__(is_mega) self.is_reusable = is_reusable self.token_count = token_count - print("initializing MinDalleTorch") + + if not is_mega: + self.model_params = load_dalle_bart_flax_params(self.model_path) + + self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') + self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') + + is_converted = os.path.exists(self.encoder_params_path) + is_converted &= os.path.exists(self.decoder_params_path) + if not is_converted: + convert_and_save_mega_torch_params(is_mega, self.model_path) + if is_reusable: self.init_encoder() self.init_decoder() @@ -39,11 +54,7 @@ class MinDalleTorch(MinDalleBase): text_token_count = self.config['max_text_length'], glu_embed_count = self.config['encoder_ffn_dim'] ) - params = convert_dalle_bart_torch_from_flax_params( - self.model_params.pop('encoder'), - layer_count=self.config['encoder_layers'], - is_encoder=True - ) + params = torch.load(self.encoder_params_path) self.encoder.load_state_dict(params, strict=False) del params if torch.cuda.is_available(): self.encoder = self.encoder.cuda() @@ -63,11 +74,7 @@ class MinDalleTorch(MinDalleBase): start_token = self.config['decoder_start_token_id'], is_verbose = True ) - params = convert_dalle_bart_torch_from_flax_params( - self.model_params.pop('decoder'), - layer_count=self.config['decoder_layers'], - is_encoder=False - ) + params = torch.load(self.decoder_params_path) self.decoder.load_state_dict(params, strict=False) del params if torch.cuda.is_available(): self.decoder = self.decoder.cuda()