diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index ac4fad6..3876006 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -2,16 +2,17 @@ import os import numpy from copy import deepcopy from typing import Dict -from flax import traverse_util, serialization +from flax.traverse_util import flatten_dict +from flax.serialization import msgpack_restore import torch torch.set_grad_enabled(False) def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]: with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f: - params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read()) + params: Dict[str, numpy.ndarray] = msgpack_restore(f.read()) - P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.') + P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.') for i in list(P.keys()): j = i @@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]: def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]: with open(os.path.join(path, "flax_model.msgpack"), "rb") as f: - params = serialization.msgpack_restore(f.read()) + params = msgpack_restore(f.read()) for codec in ['encoder', 'decoder']: k = 'FlaxBart{}Layers'.format(codec.title()) @@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params( is_encoder: bool ) -> dict: P = deepcopy(params) - P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.') + P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.') for i in P: P[i] = torch.tensor(P[i]) diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py index 4b07beb..e6376fd 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder_torch.py @@ -160,14 +160,15 @@ class DalleBartDecoderTorch(nn.Module): text_tokens: LongTensor, encoder_state: FloatTensor, attention_state: FloatTensor, - prev_token_and_index: LongTensor + prev_token: LongTensor, + token_index: LongTensor ) -> Tuple[LongTensor, FloatTensor]: attention_mask = text_tokens.not_equal(1) batch_count = encoder_state.shape[0] - prev_token = torch.cat([prev_token_and_index[:1]] * batch_count) - token_index = torch.cat([prev_token_and_index[1:]] * batch_count) - decoder_state = self.embed_tokens.forward(prev_token) - decoder_state += self.embed_positions.forward(token_index) + prev_token_batched = torch.cat([prev_token] * batch_count) + token_index_batched = torch.cat([token_index] * batch_count) + decoder_state = self.embed_tokens.forward(prev_token_batched) + decoder_state += self.embed_positions.forward(token_index_batched) decoder_state = self.layernorm_embedding.forward(decoder_state) decoder_state = decoder_state[:, None] attention_states_new = [] @@ -177,7 +178,7 @@ class DalleBartDecoderTorch(nn.Module): encoder_state, attention_state[i], attention_mask, - token_index[:1] + token_index ) attention_states_new.append(attention_state_layer) decoder_state = self.final_ln(decoder_state) @@ -185,7 +186,7 @@ class DalleBartDecoderTorch(nn.Module): a = self.condition_factor logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1] - top_logits = logits.sort(descending=True)[0][:50] + top_logits, _ = logits.topk(50, dim=-1) probs = torch.where( logits < top_logits[-1], self.zero_prob, @@ -206,12 +207,12 @@ class DalleBartDecoderTorch(nn.Module): image_token = self.start_token for i in range(self.sample_token_count): - token_index = self.token_indices[i:i+1] probs, attention_state = self.decode_step( text_tokens = text_tokens, encoder_state = encoder_state, attention_state = attention_state, - prev_token_and_index = torch.cat([image_token, token_index]) + prev_token = image_token, + token_index = self.token_indices[[i]] ) image_token = torch.multinomial(probs, 1) diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py index b0b8758..1233046 100644 --- a/min_dalle/models/vqgan_detokenizer.py +++ b/min_dalle/models/vqgan_detokenizer.py @@ -61,6 +61,7 @@ class AttentionBlock(Module): h = self.proj_out.forward(h) return x + h + class MiddleLayer(Module): def __init__(self): super().__init__() @@ -74,6 +75,7 @@ class MiddleLayer(Module): h = self.block_2.forward(h) return h + class Upsample(Module): def __init__(self, log2_count): super().__init__() @@ -86,6 +88,7 @@ class Upsample(Module): x = self.conv.forward(x) return x + class UpsampleBlock(Module): def __init__( self, @@ -124,6 +127,7 @@ class UpsampleBlock(Module): h = self.upsample.forward(h) return h + class Decoder(Module): def __init__(self): super().__init__() @@ -154,6 +158,7 @@ class Decoder(Module): z = self.conv_out.forward(z) return z + class VQGanDetokenizer(Module): def __init__(self): super().__init__()