@@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
     def init_encoder(self):
         print("initializing DalleBartEncoderTorch")
         self.encoder = DalleBartEncoderTorch(
-            layer_count = self.config['encoder_layers'],
-            embed_count = self.config['d_model'],
-            attention_head_count = self.config['encoder_attention_heads'],
-            text_vocab_count = self.config['encoder_vocab_size'],
-            text_token_count = self.config['max_text_length'],
-            glu_embed_count = self.config['encoder_ffn_dim']
+            attention_head_count = 32 if self.is_mega else 16,
+            embed_count = 2048 if self.is_mega else 1024,
+            glu_embed_count = 4096 if self.is_mega else 2730,
+            text_token_count = 64,
+            text_vocab_count = 50272 if self.is_mega else 50264,
+            layer_count = 24 if self.is_mega else 12
         )
         params = torch.load(self.encoder_params_path)
         self.encoder.load_state_dict(params, strict=False)
@@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
     def init_decoder(self):
         print("initializing DalleBartDecoderTorch")
         self.decoder = DalleBartDecoderTorch(
-            image_vocab_size = self.config['image_vocab_size'],
-            image_token_count = self.config['image_length'],
             sample_token_count = self.token_count,
-            embed_count = self.config['d_model'],
-            attention_head_count = self.config['decoder_attention_heads'],
-            glu_embed_count = self.config['decoder_ffn_dim'],
-            layer_count = self.config['decoder_layers'],
-            batch_count = 2,
-            start_token = self.config['decoder_start_token_id'],
-            is_verbose = True
+            image_token_count = 256,
+            image_vocab_count = 16415 if self.is_mega else 16384,
+            attention_head_count = 32 if self.is_mega else 16,
+            embed_count = 2048 if self.is_mega else 1024,
+            glu_embed_count = 4096 if self.is_mega else 2730,
+            layer_count = 24 if self.is_mega else 12,
+            start_token = 16415 if self.is_mega else 16384,
+            batch_count = 2
         )
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
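
For context: this change drops the config-dict lookups and hardcodes every hyperparameter per model size, keyed off self.is_mega (24 layers and 2048-dim embeddings for Mega vs. 12 and 1024 for Mini). A minimal usage sketch that exercises both branches follows; the module path and the constructor arguments is_mega and token_count are assumptions inferred from the class shown above, not confirmed by this diff:

    # Hypothetical sketch: build both model sizes so both sides of the
    # hardcoded is_mega conditionals changed in this diff are exercised.
    from min_dalle.min_dalle_torch import MinDalleTorch

    for is_mega in (False, True):
        model = MinDalleTorch(is_mega=is_mega, token_count=256)
        model.init_encoder()  # DalleBartEncoderTorch with is_mega-derived sizes
        model.init_decoder()  # DalleBartDecoderTorch with is_mega-derived sizes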