pre converting params to torch allows mega to run in standard colab runtime

This commit is contained in:
Brett Kuprel 2022-06-30 14:54:08 -04:00
parent de97fcf06b
commit b913b58353
5 changed files with 54 additions and 21 deletions

2
README.md vendored
View File

@ -5,7 +5,7 @@
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models. This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab (with nonexpendable model in high RAM runtime) It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
### Setup ### Setup

View File

@ -1,6 +1,5 @@
import os import os
import numpy import numpy
from copy import deepcopy
from typing import Dict from typing import Dict
from flax.traverse_util import flatten_dict from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore from flax.serialization import msgpack_restore
@ -105,4 +104,29 @@ def convert_dalle_bart_torch_from_flax_params(
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding') P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
P['embed_positions.weight'] = P.pop('embed_positions.embedding') P['embed_positions.weight'] = P.pop('embed_positions.embedding')
return P return P
def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
print("converting params to torch")
layer_count = 24 if is_mega else 12
flax_params = load_dalle_bart_flax_params(model_path)
encoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['encoder'],
layer_count=layer_count,
is_encoder=True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['decoder'],
layer_count=layer_count,
is_encoder=False
)
for i in decoder_params:
decoder_params[i] = decoder_params[i].to(torch.float16)
for i in encoder_params:
encoder_params[i] = encoder_params[i].to(torch.float16)
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))

View File

@ -10,12 +10,12 @@ class MinDalleBase:
def __init__(self, is_mega: bool): def __init__(self, is_mega: bool):
self.is_mega = is_mega self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
model_path = os.path.join('pretrained', model_name) self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(model_path)) print("reading files from {}".format(self.model_path))
config_path = os.path.join(model_path, 'config.json') config_path = os.path.join(self.model_path, 'config.json')
vocab_path = os.path.join(model_path, 'vocab.json') vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(model_path, 'merges.txt') merges_path = os.path.join(self.model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f: with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f) self.config = json.load(f)
@ -24,7 +24,6 @@ class MinDalleBase:
with open(merges_path, 'r', encoding='utf8') as f: with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1] merges = f.read().split("\n")[1:-1]
self.model_params = load_dalle_bart_flax_params(model_path)
self.tokenizer = TextTokenizer(vocab, merges) self.tokenizer = TextTokenizer(vocab, merges)

View File

@ -7,12 +7,15 @@ from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
from .load_params import load_dalle_bart_flax_params
class MinDalleFlax(MinDalleBase): class MinDalleFlax(MinDalleBase):
def __init__(self, is_mega: bool, is_reusable: bool = True): def __init__(self, is_mega: bool, is_reusable: bool = True):
super().__init__(is_mega) super().__init__(is_mega)
self.is_reusable = is_reusable self.is_reusable = is_reusable
print("initializing MinDalleFlax") print("initializing MinDalleFlax")
self.model_params = load_dalle_bart_flax_params(self.model_path)
if is_reusable: if is_reusable:
self.init_encoder() self.init_encoder()
self.init_decoder() self.init_decoder()

View File

@ -6,7 +6,10 @@ import torch
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count()) torch.set_num_threads(os.cpu_count())
from .load_params import convert_dalle_bart_torch_from_flax_params from .load_params import (
convert_and_save_mega_torch_params,
load_dalle_bart_flax_params
)
from .min_dalle_base import MinDalleBase from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
@ -19,10 +22,22 @@ class MinDalleTorch(MinDalleBase):
is_reusable: bool = True, is_reusable: bool = True,
token_count: int = 256 token_count: int = 256
): ):
print("initializing MinDalleTorch")
super().__init__(is_mega) super().__init__(is_mega)
self.is_reusable = is_reusable self.is_reusable = is_reusable
self.token_count = token_count self.token_count = token_count
print("initializing MinDalleTorch")
if not is_mega:
self.model_params = load_dalle_bart_flax_params(self.model_path)
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
is_converted = os.path.exists(self.encoder_params_path)
is_converted &= os.path.exists(self.decoder_params_path)
if not is_converted:
convert_and_save_mega_torch_params(is_mega, self.model_path)
if is_reusable: if is_reusable:
self.init_encoder() self.init_encoder()
self.init_decoder() self.init_decoder()
@ -39,11 +54,7 @@ class MinDalleTorch(MinDalleBase):
text_token_count = self.config['max_text_length'], text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim'] glu_embed_count = self.config['encoder_ffn_dim']
) )
params = convert_dalle_bart_torch_from_flax_params( params = torch.load(self.encoder_params_path)
self.model_params.pop('encoder'),
layer_count=self.config['encoder_layers'],
is_encoder=True
)
self.encoder.load_state_dict(params, strict=False) self.encoder.load_state_dict(params, strict=False)
del params del params
if torch.cuda.is_available(): self.encoder = self.encoder.cuda() if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
@ -63,11 +74,7 @@ class MinDalleTorch(MinDalleBase):
start_token = self.config['decoder_start_token_id'], start_token = self.config['decoder_start_token_id'],
is_verbose = True is_verbose = True
) )
params = convert_dalle_bart_torch_from_flax_params( params = torch.load(self.decoder_params_path)
self.model_params.pop('decoder'),
layer_count=self.config['decoder_layers'],
is_encoder=False
)
self.decoder.load_state_dict(params, strict=False) self.decoder.load_state_dict(params, strict=False)
del params del params
if torch.cuda.is_available(): self.decoder = self.decoder.cuda() if torch.cuda.is_available(): self.decoder = self.decoder.cuda()