diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index c51a3a9..bfdfd93 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -128,5 +128,9 @@ def convert_and_save_torch_params(is_mega: bool, model_path: str): for i in encoder_params: encoder_params[i] = encoder_params[i].to(torch.float16) + detoker_params = load_vqgan_torch_params('./pretrained/vqgan') + detoker_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt') + torch.save(encoder_params, os.path.join(model_path, 'encoder.pt')) - torch.save(decoder_params, os.path.join(model_path, 'decoder.pt')) \ No newline at end of file + torch.save(decoder_params, os.path.join(model_path, 'decoder.pt')) + torch.save(detoker_params, detoker_path) \ No newline at end of file diff --git a/min_dalle/min_dalle_base.py b/min_dalle/min_dalle_base.py index 1bde741..0d196b3 100644 --- a/min_dalle/min_dalle_base.py +++ b/min_dalle/min_dalle_base.py @@ -3,7 +3,6 @@ import json import numpy from .text_tokenizer import TextTokenizer -from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params from .models.vqgan_detokenizer import VQGanDetokenizer class MinDalleBase: @@ -27,20 +26,12 @@ class MinDalleBase: self.tokenizer = TextTokenizer(vocab, merges) - def init_detokenizer(self): - print("initializing VQGanDetokenizer") - params = load_vqgan_torch_params('./pretrained/vqgan') - self.detokenizer = VQGanDetokenizer() - self.detokenizer.load_state_dict(params) - del params - - def tokenize_text(self, text: str) -> numpy.ndarray: print("tokenizing text") tokens = self.tokenizer.tokenize(text) print("text tokens", tokens) text_token_count = self.config['max_text_length'] text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32) - text_tokens[0, :len(tokens)] = tokens - text_tokens[1, :2] = [tokens[0], tokens[-1]] + text_tokens[0, :2] = [tokens[0], tokens[-1]] + text_tokens[1, :len(tokens)] = tokens return text_tokens \ No newline at end of file diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py index 176ce6b..5a367b2 100644 --- a/min_dalle/min_dalle_flax.py +++ b/min_dalle/min_dalle_flax.py @@ -6,8 +6,9 @@ import torch from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax +from .models.vqgan_detokenizer import VQGanDetokenizer -from .load_params import load_dalle_bart_flax_params +from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params class MinDalleFlax(MinDalleBase): @@ -32,7 +33,7 @@ class MinDalleFlax(MinDalleBase): text_vocab_count = self.config['encoder_vocab_size'], layer_count = self.config['encoder_layers'] ).bind({'params': self.model_params.pop('encoder')}) - + def init_decoder(self): print("initializing DalleBartDecoderFlax") @@ -46,7 +47,14 @@ class MinDalleFlax(MinDalleBase): layer_count = self.config['decoder_layers'], start_token = self.config['decoder_start_token_id'] ) - + + + def init_detokenizer(self): + print("initializing VQGanDetokenizer") + params = load_vqgan_torch_params('./pretrained/vqgan') + self.detokenizer = VQGanDetokenizer() + self.detokenizer.load_state_dict(params) + del params def generate_image(self, text: str, seed: int) -> Image.Image: text_tokens = self.tokenize_text(text) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index e5a0699..1593818 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -6,13 +6,11 @@ import torch torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) -from .load_params import ( - convert_and_save_torch_params, - load_dalle_bart_flax_params -) +from .load_params import convert_and_save_torch_params from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch +from .models.vqgan_detokenizer import VQGanDetokenizer class MinDalleTorch(MinDalleBase): @@ -26,15 +24,14 @@ class MinDalleTorch(MinDalleBase): super().__init__(is_mega) self.is_reusable = is_reusable self.token_count = token_count - - if not is_mega: - self.model_params = load_dalle_bart_flax_params(self.model_path) self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') + self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt') is_converted = os.path.exists(self.encoder_params_path) is_converted &= os.path.exists(self.decoder_params_path) + is_converted &= os.path.exists(self.detoker_params_path) if not is_converted: convert_and_save_torch_params(is_mega, self.model_path) @@ -79,11 +76,14 @@ class MinDalleTorch(MinDalleBase): del params if torch.cuda.is_available(): self.decoder = self.decoder.cuda() - + def init_detokenizer(self): - super().init_detokenizer() - if torch.cuda.is_available(): - self.detokenizer = self.detokenizer.cuda() + print("initializing VQGanDetokenizer") + self.detokenizer = VQGanDetokenizer() + params = torch.load(self.detoker_params_path) + self.detokenizer.load_state_dict(params) + del params + if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() def generate_image_tokens(self, text: str, seed: int) -> LongTensor: diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index b7965a4..3ef4b3a 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -37,7 +37,8 @@ class DecoderSelfAttentionFlax(AttentionFlax): state_index ) batch_count = decoder_state.shape[0] - keys, values = attention_state[:batch_count], attention_state[batch_count:] + keys = attention_state[:batch_count] + values = attention_state[batch_count:] decoder_state = self.forward( keys, @@ -120,7 +121,7 @@ class SampleState: attention_state: jnp.ndarray def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray: - return a * logits[0, -1] + (1 - a) * logits[1, -1] + return (1 - a) * logits[0, -1] + a * logits[1, -1] def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray: top_logits, _ = lax.top_k(logits, k) diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py index e6376fd..7b49344 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder_torch.py @@ -184,7 +184,7 @@ class DalleBartDecoderTorch(nn.Module): decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) a = self.condition_factor - logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1] + logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1] top_logits, _ = logits.topk(50, dim=-1) probs = torch.where( diff --git a/setup.sh b/setup.sh index d405671..ee465bd 100644 --- a/setup.sh +++ b/setup.sh @@ -4,12 +4,11 @@ set -e pip install -r requirements.txt -mkdir -p pretrained/vqgan - # download vqgan +mkdir -p pretrained/vqgan curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack -# download dalle-mini and dalle mega +# download dalle-mini and dalle-mega python -m wandb login --anonymously python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0 python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14