Separate setup processes for Flax and Torch

This commit is contained in:
Brett Kuprel
2022-07-01 11:08:33 -04:00
parent 7bf76deafb
commit 09a0f85b8e
5 changed files with 16 additions and 11 deletions

View File

@@ -6,7 +6,6 @@ import torch
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
from .load_params import convert_and_save_torch_params
from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
@@ -29,12 +28,6 @@ class MinDalleTorch(MinDalleBase):
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
is_converted = os.path.exists(self.encoder_params_path)
is_converted &= os.path.exists(self.decoder_params_path)
is_converted &= os.path.exists(self.detoker_params_path)
if not is_converted:
convert_and_save_torch_params(is_mega, self.model_path)
if is_reusable:
self.init_encoder()
self.init_decoder()