save converted detokenizer params
This commit is contained in:
parent
8b5960b687
commit
e4c2be54cb
|
@ -128,5 +128,9 @@ def convert_and_save_torch_params(is_mega: bool, model_path: str):
|
||||||
for i in encoder_params:
|
for i in encoder_params:
|
||||||
encoder_params[i] = encoder_params[i].to(torch.float16)
|
encoder_params[i] = encoder_params[i].to(torch.float16)
|
||||||
|
|
||||||
|
detoker_params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||||
|
detoker_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt')
|
||||||
|
|
||||||
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
|
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
|
||||||
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))
|
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))
|
||||||
|
torch.save(detoker_params, detoker_path)
|
|
@ -3,7 +3,6 @@ import json
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from .text_tokenizer import TextTokenizer
|
from .text_tokenizer import TextTokenizer
|
||||||
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
|
|
||||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
class MinDalleBase:
|
class MinDalleBase:
|
||||||
|
@ -27,20 +26,12 @@ class MinDalleBase:
|
||||||
self.tokenizer = TextTokenizer(vocab, merges)
|
self.tokenizer = TextTokenizer(vocab, merges)
|
||||||
|
|
||||||
|
|
||||||
def init_detokenizer(self):
|
|
||||||
print("initializing VQGanDetokenizer")
|
|
||||||
params = load_vqgan_torch_params('./pretrained/vqgan')
|
|
||||||
self.detokenizer = VQGanDetokenizer()
|
|
||||||
self.detokenizer.load_state_dict(params)
|
|
||||||
del params
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_text(self, text: str) -> numpy.ndarray:
|
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||||
print("tokenizing text")
|
print("tokenizing text")
|
||||||
tokens = self.tokenizer.tokenize(text)
|
tokens = self.tokenizer.tokenize(text)
|
||||||
print("text tokens", tokens)
|
print("text tokens", tokens)
|
||||||
text_token_count = self.config['max_text_length']
|
text_token_count = self.config['max_text_length']
|
||||||
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
||||||
text_tokens[0, :len(tokens)] = tokens
|
text_tokens[0, :2] = [tokens[0], tokens[-1]]
|
||||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
text_tokens[1, :len(tokens)] = tokens
|
||||||
return text_tokens
|
return text_tokens
|
|
@ -6,8 +6,9 @@ import torch
|
||||||
from .min_dalle_base import MinDalleBase
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||||
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
from .load_params import load_dalle_bart_flax_params
|
from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params
|
||||||
|
|
||||||
|
|
||||||
class MinDalleFlax(MinDalleBase):
|
class MinDalleFlax(MinDalleBase):
|
||||||
|
@ -32,7 +33,7 @@ class MinDalleFlax(MinDalleBase):
|
||||||
text_vocab_count = self.config['encoder_vocab_size'],
|
text_vocab_count = self.config['encoder_vocab_size'],
|
||||||
layer_count = self.config['encoder_layers']
|
layer_count = self.config['encoder_layers']
|
||||||
).bind({'params': self.model_params.pop('encoder')})
|
).bind({'params': self.model_params.pop('encoder')})
|
||||||
|
|
||||||
|
|
||||||
def init_decoder(self):
|
def init_decoder(self):
|
||||||
print("initializing DalleBartDecoderFlax")
|
print("initializing DalleBartDecoderFlax")
|
||||||
|
@ -46,7 +47,14 @@ class MinDalleFlax(MinDalleBase):
|
||||||
layer_count = self.config['decoder_layers'],
|
layer_count = self.config['decoder_layers'],
|
||||||
start_token = self.config['decoder_start_token_id']
|
start_token = self.config['decoder_start_token_id']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_detokenizer(self):
|
||||||
|
print("initializing VQGanDetokenizer")
|
||||||
|
params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||||
|
self.detokenizer = VQGanDetokenizer()
|
||||||
|
self.detokenizer.load_state_dict(params)
|
||||||
|
del params
|
||||||
|
|
||||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||||
text_tokens = self.tokenize_text(text)
|
text_tokens = self.tokenize_text(text)
|
||||||
|
|
|
@ -6,13 +6,11 @@ import torch
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
from .load_params import (
|
from .load_params import convert_and_save_torch_params
|
||||||
convert_and_save_torch_params,
|
|
||||||
load_dalle_bart_flax_params
|
|
||||||
)
|
|
||||||
from .min_dalle_base import MinDalleBase
|
from .min_dalle_base import MinDalleBase
|
||||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||||
|
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||||
|
|
||||||
|
|
||||||
class MinDalleTorch(MinDalleBase):
|
class MinDalleTorch(MinDalleBase):
|
||||||
|
@ -26,15 +24,14 @@ class MinDalleTorch(MinDalleBase):
|
||||||
super().__init__(is_mega)
|
super().__init__(is_mega)
|
||||||
self.is_reusable = is_reusable
|
self.is_reusable = is_reusable
|
||||||
self.token_count = token_count
|
self.token_count = token_count
|
||||||
|
|
||||||
if not is_mega:
|
|
||||||
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
|
||||||
|
|
||||||
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
||||||
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||||
|
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt')
|
||||||
|
|
||||||
is_converted = os.path.exists(self.encoder_params_path)
|
is_converted = os.path.exists(self.encoder_params_path)
|
||||||
is_converted &= os.path.exists(self.decoder_params_path)
|
is_converted &= os.path.exists(self.decoder_params_path)
|
||||||
|
is_converted &= os.path.exists(self.detoker_params_path)
|
||||||
if not is_converted:
|
if not is_converted:
|
||||||
convert_and_save_torch_params(is_mega, self.model_path)
|
convert_and_save_torch_params(is_mega, self.model_path)
|
||||||
|
|
||||||
|
@ -79,11 +76,14 @@ class MinDalleTorch(MinDalleBase):
|
||||||
del params
|
del params
|
||||||
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
||||||
|
|
||||||
|
|
||||||
def init_detokenizer(self):
|
def init_detokenizer(self):
|
||||||
super().init_detokenizer()
|
print("initializing VQGanDetokenizer")
|
||||||
if torch.cuda.is_available():
|
self.detokenizer = VQGanDetokenizer()
|
||||||
self.detokenizer = self.detokenizer.cuda()
|
params = torch.load(self.detoker_params_path)
|
||||||
|
self.detokenizer.load_state_dict(params)
|
||||||
|
del params
|
||||||
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||||
|
|
|
@ -37,7 +37,8 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||||
state_index
|
state_index
|
||||||
)
|
)
|
||||||
batch_count = decoder_state.shape[0]
|
batch_count = decoder_state.shape[0]
|
||||||
keys, values = attention_state[:batch_count], attention_state[batch_count:]
|
keys = attention_state[:batch_count]
|
||||||
|
values = attention_state[batch_count:]
|
||||||
|
|
||||||
decoder_state = self.forward(
|
decoder_state = self.forward(
|
||||||
keys,
|
keys,
|
||||||
|
@ -120,7 +121,7 @@ class SampleState:
|
||||||
attention_state: jnp.ndarray
|
attention_state: jnp.ndarray
|
||||||
|
|
||||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
return (1 - a) * logits[0, -1] + a * logits[1, -1]
|
||||||
|
|
||||||
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||||
top_logits, _ = lax.top_k(logits, k)
|
top_logits, _ = lax.top_k(logits, k)
|
||||||
|
|
|
@ -184,7 +184,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
|
logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
|
||||||
|
|
||||||
top_logits, _ = logits.topk(50, dim=-1)
|
top_logits, _ = logits.topk(50, dim=-1)
|
||||||
probs = torch.where(
|
probs = torch.where(
|
||||||
|
|
5
setup.sh
vendored
5
setup.sh
vendored
|
@ -4,12 +4,11 @@ set -e
|
||||||
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
mkdir -p pretrained/vqgan
|
|
||||||
|
|
||||||
# download vqgan
|
# download vqgan
|
||||||
|
mkdir -p pretrained/vqgan
|
||||||
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
|
curl https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/resolve/main/flax_model.msgpack -L --output ./pretrained/vqgan/flax_model.msgpack
|
||||||
|
|
||||||
# download dalle-mini and dalle mega
|
# download dalle-mini and dalle-mega
|
||||||
python -m wandb login --anonymously
|
python -m wandb login --anonymously
|
||||||
python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
python -m wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
||||||
python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
python -m wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
||||||
|
|
Loading…
Reference in New Issue
Block a user