updated readme
@@ -107,7 +107,7 @@ def convert_dalle_bart_torch_from_flax_params(
     return P


-def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
+def convert_and_save_torch_params(is_mega: bool, model_path: str):
    print("converting params to torch")
    layer_count = 24 if is_mega else 12
    flax_params = load_dalle_bart_flax_params(model_path)
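The rename drops the "mega" prefix because nothing in the function body is Mega-specific: the is_mega flag already selects the configuration (layer_count = 24 if is_mega else 12). A hedged sketch of the resulting call sites, with hypothetical model paths and an assumed min_dalle.load_params module location:

    from min_dalle.load_params import convert_and_save_torch_params  # module path assumed

    # One entry point now converts either model size; is_mega picks the config.
    convert_and_save_torch_params(is_mega=True, model_path='./pretrained/dalle_bart_mega')   # hypothetical path
    convert_and_save_torch_params(is_mega=False, model_path='./pretrained/dalle_bart_mini')  # hypothetical path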
@@ -7,7 +7,7 @@ torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())

 from .load_params import (
-    convert_and_save_mega_torch_params,
+    convert_and_save_torch_params,
     load_dalle_bart_flax_params
 )
 from .min_dalle_base import MinDalleBase
@@ -36,7 +36,7 @@ class MinDalleTorch(MinDalleBase):
         is_converted = os.path.exists(self.encoder_params_path)
         is_converted &= os.path.exists(self.decoder_params_path)
         if not is_converted:
-            convert_and_save_mega_torch_params(is_mega, self.model_path)
+            convert_and_save_torch_params(is_mega, self.model_path)

         if is_reusable:
             self.init_encoder()
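For context, a minimal usage sketch of the updated call path (import path assumed, not shown in this diff; is_mega and is_reusable are the constructor arguments referenced in the hunk above): instantiating MinDalleTorch checks for the saved torch params and runs convert_and_save_torch_params only when they are missing.

    from min_dalle.min_dalle_torch import MinDalleTorch  # import path assumed

    # First run converts the flax params to torch and saves them to disk;
    # later runs find the encoder/decoder param files and skip conversion.
    model = MinDalleTorch(is_mega=False, is_reusable=True)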