diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index 3876006..38deec2 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -82,11 +82,10 @@ def convert_dalle_bart_torch_from_flax_params( layer_count: int, is_encoder: bool ) -> dict: - P = deepcopy(params) - P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.') + P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.') for i in P: - P[i] = torch.tensor(P[i]) + P[i] = torch.tensor(P[i]).to(torch.float16) for i in list(P): if 'kernel' in i: diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py index 8886f90..b60b538 100644 --- a/min_dalle/min_dalle_flax.py +++ b/min_dalle/min_dalle_flax.py @@ -28,7 +28,7 @@ class MinDalleFlax(MinDalleBase): text_token_count = self.config['max_text_length'], text_vocab_count = self.config['encoder_vocab_size'], layer_count = self.config['encoder_layers'] - ).bind({'params': self.model_params['encoder']}) + ).bind({'params': self.model_params.pop('encoder')}) def init_decoder(self): @@ -53,13 +53,17 @@ class MinDalleFlax(MinDalleBase): encoder_state = self.encoder(text_tokens) if self.is_expendable: del self.encoder - if self.is_expendable: self.init_decoder() + if self.is_expendable: + self.init_decoder() + params = self.model_params.pop('decoder') + else: + params = self.model_params['decoder'] print("sampling image tokens") image_tokens = self.decoder.sample_image_tokens( text_tokens, encoder_state, jax.random.PRNGKey(seed), - self.model_params['decoder'] + params ) if self.is_expendable: del self.decoder diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 0efffdf..bdfa662 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -40,12 +40,13 @@ class MinDalleTorch(MinDalleBase): glu_embed_count = self.config['encoder_ffn_dim'] ) params = convert_dalle_bart_torch_from_flax_params( - self.model_params['encoder'], + self.model_params.pop('encoder'), layer_count=self.config['encoder_layers'], is_encoder=True ) self.encoder.load_state_dict(params, strict=False) if torch.cuda.is_available(): self.encoder = self.encoder.cuda() + del params def init_decoder(self): @@ -63,12 +64,13 @@ class MinDalleTorch(MinDalleBase): is_verbose = True ) params = convert_dalle_bart_torch_from_flax_params( - self.model_params['decoder'], + self.model_params.pop('decoder'), layer_count=self.config['decoder_layers'], is_encoder=False ) self.decoder.load_state_dict(params, strict=False) if torch.cuda.is_available(): self.decoder = self.decoder.cuda() + del params def init_detokenizer(self):