Pre-converting params to torch allows Mega to run in the standard Colab runtime
This commit is contained in:
parent
de97fcf06b
commit
b913b58353
2
README.md
vendored
2
README.md
vendored
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
|
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
|
||||||
|
|
||||||
It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab (with nonexpendable model in high RAM runtime)
|
It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
import numpy
|
import numpy
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
from flax.serialization import msgpack_restore
|
from flax.serialization import msgpack_restore
|
||||||
|
@ -106,3 +105,28 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||||
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
||||||
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
||||||
return P
|
return P
|
||||||
|
|
||||||
|
|
||||||
|
def _save_params_as_float16(params, save_path):
    """Cast every tensor in `params` to float16 in place, then save atomically.

    The file is first written next to its destination and then renamed with
    `os.replace`, so `save_path` only ever appears once it is complete. This
    matters for callers that decide whether conversion already happened by
    checking that the file exists: an interrupted run must not leave a
    truncated file behind under the final name.
    """
    # Replace entries one at a time so each full-precision tensor can be
    # released as soon as its half-precision copy exists.
    for name in params:
        params[name] = params[name].to(torch.float16)

    temp_path = save_path + '.tmp'
    torch.save(params, temp_path)
    os.replace(temp_path, save_path)


def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
    """Convert the flax DALL-E BART params to torch and save them to disk.

    Despite the name, this handles both model sizes: `is_mega` only selects
    the layer count (24 for mega, 12 otherwise).

    Writes `encoder.pt` and `decoder.pt` into `model_path`, each holding the
    converted parameters cast to float16.

    Args:
        is_mega: whether the params in `model_path` belong to the mega model.
        model_path: directory the flax params are read from and the
            converted `.pt` files are written to.
    """
    print("converting params to torch")
    layer_count = 24 if is_mega else 12
    flax_params = load_dalle_bart_flax_params(model_path)

    # Handle encoder and decoder one after the other, popping each subtree
    # out of the flax params and dropping the converted copy once saved, so
    # both halves are not held in memory at the same time.
    for part, is_encoder in (('encoder', True), ('decoder', False)):
        torch_params = convert_dalle_bart_torch_from_flax_params(
            flax_params.pop(part),
            layer_count=layer_count,
            is_encoder=is_encoder
        )
        _save_params_as_float16(
            torch_params,
            os.path.join(model_path, part + '.pt')
        )
        del torch_params
|
|
@ -10,12 +10,12 @@ class MinDalleBase:
|
||||||
def __init__(self, is_mega: bool):
|
def __init__(self, is_mega: bool):
|
||||||
self.is_mega = is_mega
|
self.is_mega = is_mega
|
||||||
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||||
model_path = os.path.join('pretrained', model_name)
|
self.model_path = os.path.join('pretrained', model_name)
|
||||||
|
|
||||||
print("reading files from {}".format(model_path))
|
print("reading files from {}".format(self.model_path))
|
||||||
config_path = os.path.join(model_path, 'config.json')
|
config_path = os.path.join(self.model_path, 'config.json')
|
||||||
vocab_path = os.path.join(model_path, 'vocab.json')
|
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
||||||
merges_path = os.path.join(model_path, 'merges.txt')
|
merges_path = os.path.join(self.model_path, 'merges.txt')
|
||||||
|
|
||||||
with open(config_path, 'r', encoding='utf8') as f:
|
with open(config_path, 'r', encoding='utf8') as f:
|
||||||
self.config = json.load(f)
|
self.config = json.load(f)
|
||||||
|
@ -24,7 +24,6 @@ class MinDalleBase:
|
||||||
with open(merges_path, 'r', encoding='utf8') as f:
|
with open(merges_path, 'r', encoding='utf8') as f:
|
||||||
merges = f.read().split("\n")[1:-1]
|
merges = f.read().split("\n")[1:-1]
|
||||||
|
|
||||||
self.model_params = load_dalle_bart_flax_params(model_path)
|
|
||||||
self.tokenizer = TextTokenizer(vocab, merges)
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,15 @@ from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||||
|
|
||||||
|
from .load_params import load_dalle_bart_flax_params
|
||||||
|
|
||||||
|
|
||||||
class MinDalleFlax(MinDalleBase):
|
class MinDalleFlax(MinDalleBase):
|
||||||
def __init__(self, is_mega: bool, is_reusable: bool = True):
|
def __init__(self, is_mega: bool, is_reusable: bool = True):
|
||||||
super().__init__(is_mega)
|
super().__init__(is_mega)
|
||||||
self.is_reusable = is_reusable
|
self.is_reusable = is_reusable
|
||||||
print("initializing MinDalleFlax")
|
print("initializing MinDalleFlax")
|
||||||
|
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
||||||
if is_reusable:
|
if is_reusable:
|
||||||
self.init_encoder()
|
self.init_encoder()
|
||||||
self.init_decoder()
|
self.init_decoder()
|
||||||
|
|
|
@ -6,7 +6,10 @@ import torch
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .load_params import convert_dalle_bart_torch_from_flax_params
|
from .load_params import (
|
||||||
|
convert_and_save_mega_torch_params,
|
||||||
|
load_dalle_bart_flax_params
|
||||||
|
)
|
||||||
from .min_dalle_base import MinDalleBase
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
@ -19,10 +22,22 @@ class MinDalleTorch(MinDalleBase):
|
||||||
is_reusable: bool = True,
|
is_reusable: bool = True,
|
||||||
token_count: int = 256
|
token_count: int = 256
|
||||||
):
|
):
|
||||||
|
print("initializing MinDalleTorch")
|
||||||
super().__init__(is_mega)
|
super().__init__(is_mega)
|
||||||
self.is_reusable = is_reusable
|
self.is_reusable = is_reusable
|
||||||
self.token_count = token_count
|
self.token_count = token_count
|
||||||
print("initializing MinDalleTorch")
|
|
||||||
|
if not is_mega:
|
||||||
|
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
||||||
|
|
||||||
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
||||||
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||||
|
|
||||||
|
is_converted = os.path.exists(self.encoder_params_path)
|
||||||
|
is_converted &= os.path.exists(self.decoder_params_path)
|
||||||
|
if not is_converted:
|
||||||
|
convert_and_save_mega_torch_params(is_mega, self.model_path)
|
||||||
|
|
||||||
if is_reusable:
|
if is_reusable:
|
||||||
self.init_encoder()
|
self.init_encoder()
|
||||||
self.init_decoder()
|
self.init_decoder()
|
||||||
|
@ -39,11 +54,7 @@ class MinDalleTorch(MinDalleBase):
|
||||||
text_token_count = self.config['max_text_length'],
|
text_token_count = self.config['max_text_length'],
|
||||||
glu_embed_count = self.config['encoder_ffn_dim']
|
glu_embed_count = self.config['encoder_ffn_dim']
|
||||||
)
|
)
|
||||||
params = convert_dalle_bart_torch_from_flax_params(
|
params = torch.load(self.encoder_params_path)
|
||||||
self.model_params.pop('encoder'),
|
|
||||||
layer_count=self.config['encoder_layers'],
|
|
||||||
is_encoder=True
|
|
||||||
)
|
|
||||||
self.encoder.load_state_dict(params, strict=False)
|
self.encoder.load_state_dict(params, strict=False)
|
||||||
del params
|
del params
|
||||||
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
|
if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
|
||||||
|
@ -63,11 +74,7 @@ class MinDalleTorch(MinDalleBase):
|
||||||
start_token = self.config['decoder_start_token_id'],
|
start_token = self.config['decoder_start_token_id'],
|
||||||
is_verbose = True
|
is_verbose = True
|
||||||
)
|
)
|
||||||
params = convert_dalle_bart_torch_from_flax_params(
|
params = torch.load(self.decoder_params_path)
|
||||||
self.model_params.pop('decoder'),
|
|
||||||
layer_count=self.config['decoder_layers'],
|
|
||||||
is_encoder=False
|
|
||||||
)
|
|
||||||
self.decoder.load_state_dict(params, strict=False)
|
self.decoder.load_state_dict(params, strict=False)
|
||||||
del params
|
del params
|
||||||
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user