diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..8470575 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +**/__pycache__/ +**/.cache/ +**/*.pkl +**/*.png +**/.DS* +*.pt +*.mlpackage +**/*.ckpt +.vscode +**/.ipynb_checkpoints +**/generated +**/pretrained +**/*.msgpack diff --git a/image_from_text.py b/image_from_text.py new file mode 100644 index 0000000..b77615f --- /dev/null +++ b/image_from_text.py @@ -0,0 +1,70 @@ +import os +import json +import numpy +import torch +from PIL import Image +from typing import Tuple, List + +from text_tokenizer import TextTokenizer +from models.vqgan_detokenizer import VQGanDetokenizer +from load_params import load_vqgan_torch_params + + +def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]: + print("loading model") + for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']: + assert(os.path.exists(os.path.join(path, f))) + with open(path + '/config.json', 'r') as f: + config = json.load(f) + with open(path + '/vocab.json') as f: + vocab = json.load(f) + with open(path + '/merges.txt') as f: + merges = f.read().split("\n")[1:-1] + return config, vocab, merges + + +def ascii_from_image(image: Image.Image, size: int) -> str: + rgb_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata() + chars = list('.,;/IOX') + chars = [chars[i * len(chars) // 256] for i in rgb_pixels] + chars = [chars[i * size: (i + 1) * size] for i in range(size // 2)] + return '\n'.join(''.join(row) for row in chars) + + +def save_image(image: numpy.ndarray, path: str) -> Image.Image: + if os.path.isdir(path): + path = os.path.join(path, 'generated.png') + elif not path.endswith('.png'): + path += '.png' + print("saving image to", path) + image: Image.Image = Image.fromarray(numpy.asarray(image)) + image.save(path) + return image + + +def tokenize( + text: str, + config: dict, + vocab: dict, + merges: List[str] +) -> numpy.ndarray: + print("tokenizing text") + tokens = TextTokenizer(vocab, merges)(text) + print("text tokens", tokens) + text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32) + text_tokens[0, :len(tokens)] = tokens + text_tokens[1, :2] = [tokens[0], tokens[-1]] + return text_tokens + + +def detokenize_torch( + image_tokens: numpy.ndarray, + model_path: str +) -> numpy.ndarray: + print("detokenizing image") + params = load_vqgan_torch_params(model_path) + detokenizer = VQGanDetokenizer() + detokenizer.load_state_dict(params) + image_tokens = torch.tensor(image_tokens).to(torch.long) + image = detokenizer.forward(image_tokens).to(torch.uint8) + return image.detach().numpy() \ No newline at end of file diff --git a/image_from_text_flax.py b/image_from_text_flax.py new file mode 100644 index 0000000..3f6f56a --- /dev/null +++ b/image_from_text_flax.py @@ -0,0 +1,126 @@ +import jax +from jax import numpy as jnp +import numpy +import argparse + +from load_params import load_dalle_bart_flax_params +from image_from_text import ( + load_dalle_bart_metadata, + tokenize, + detokenize_torch, + save_image, + ascii_from_image +) +from models.dalle_bart_encoder_flax import DalleBartEncoderFlax +from models.dalle_bart_decoder_flax import DalleBartDecoderFlax + + +parser = argparse.ArgumentParser() +parser.add_argument( + '--text', + help='text to generate image from', + type=str +) +parser.add_argument( + '--seed', + help='random seed', + type=int, + default=0 +) +parser.add_argument( + '--image_path', + help='generated image path', + type=str, + default='generated.png' +) +parser.add_argument( + '--dalle_bart_path', + help='pretraied dalle bart path', + type=str, + default='./pretrained/dalle_bart_mini' +) +parser.add_argument( + '--vqgan_path', + help='pretraied vqgan path', + type=str, + default='./pretrained/vqgan' +) + + +def encode_flax( + text_tokens: numpy.ndarray, + config: dict, + params: dict +) -> jnp.ndarray: + print("loading flax encoder") + encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( + attention_head_count = config['encoder_attention_heads'], + embed_count = config['d_model'], + glu_embed_count = config['encoder_ffn_dim'], + text_token_count = config['max_text_length'], + text_vocab_count = config['encoder_vocab_size'], + layer_count = config['encoder_layers'] + ).bind({'params': params.pop('encoder')}) + + print("encoding text tokens") + encoder_state = encoder(text_tokens) + del encoder + return encoder_state + +def decode_flax( + text_tokens: jnp.ndarray, + encoder_state: jnp.ndarray, + config: dict, + seed: int, + params: dict +) -> jnp.ndarray: + print("loading flax decoder") + decoder = DalleBartDecoderFlax( + image_token_count = config['image_length'], + text_token_count = config['max_text_length'], + image_vocab_count = config['image_vocab_size'], + attention_head_count = config['decoder_attention_heads'], + embed_count = config['d_model'], + glu_embed_count = config['decoder_ffn_dim'], + layer_count = config['decoder_layers'], + start_token = config['decoder_start_token_id'] + ) + print("sampling image tokens") + image_tokens = decoder.sample_image_tokens( + text_tokens, + encoder_state, + jax.random.PRNGKey(seed), + params.pop('decoder') + ) + del decoder + return image_tokens + +def generate_image_tokens_flax( + text: str, + seed: int, + dalle_bart_path: str +) -> numpy.ndarray: + config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path) + text_tokens = tokenize(text, config, vocab, merges) + params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path) + encoder_state = encode_flax(text_tokens, config, params_dalle_bart) + image_tokens = decode_flax( + text_tokens, + encoder_state, + config, seed, + params_dalle_bart + ) + return numpy.array(image_tokens) + +if __name__ == '__main__': + args = parser.parse_args() + + image_tokens = generate_image_tokens_flax( + args.text, + args.seed, + args.dalle_bart_path + ) + print("image tokens", list(image_tokens)) + image = detokenize_torch(image_tokens, args.vqgan_path) + image = save_image(image, args.image_path) + print(ascii_from_image(image, size=128)) \ No newline at end of file diff --git a/image_from_text_torch.py b/image_from_text_torch.py new file mode 100644 index 0000000..43cfad4 --- /dev/null +++ b/image_from_text_torch.py @@ -0,0 +1,159 @@ +import numpy +import torch +from torch import Tensor +import argparse +from typing import Dict + +from image_from_text import ( + load_dalle_bart_metadata, + tokenize, + detokenize_torch, + save_image, + ascii_from_image +) +from models.dalle_bart_encoder_torch import DalleBartEncoderTorch +from models.dalle_bart_decoder_torch import DalleBartDecoderTorch + +from load_params import ( + load_dalle_bart_flax_params, + convert_dalle_bart_torch_from_flax_params +) + +parser = argparse.ArgumentParser() +parser.add_argument( + '--text', + help='text to generate image from', + type=str +) +parser.add_argument( + '--seed', + help='random seed', + type=int, + default=0 +) +parser.add_argument( + '--image_token_count', + help='image tokens to sample', + type=int, + default=256 +) +parser.add_argument( + '--image_path', + help='generated image path', + type=str, + default='generated.png' +) +parser.add_argument( + '--dalle_bart_path', + help='pretraied dalle bart path', + type=str, + default='./pretrained/dalle_bart_mini' +) +parser.add_argument( + '--vqgan_path', + help='pretraied vqgan path', + type=str, + default='./pretrained/vqgan' +) + + +def encode_torch( + text_tokens: numpy.ndarray, + config: dict, + params: dict +) -> Tensor: + print("loading torch encoder") + encoder = DalleBartEncoderTorch( + layer_count = config['encoder_layers'], + embed_count = config['d_model'], + attention_head_count = config['encoder_attention_heads'], + text_vocab_count = config['encoder_vocab_size'], + text_token_count = config['max_text_length'], + glu_embed_count = config['encoder_ffn_dim'] + ) + encoder_params = convert_dalle_bart_torch_from_flax_params( + params.pop('encoder'), + layer_count=config['encoder_layers'], + is_encoder=True + ) + encoder.load_state_dict(encoder_params, strict=False) + del encoder_params + + print("encoding text tokens") + text_tokens = torch.tensor(text_tokens).to(torch.long) + encoder_state = encoder(text_tokens) + del encoder + return encoder_state + + +def decode_torch( + text_tokens: Tensor, + encoder_state: Tensor, + config: dict, + seed: int, + params: dict, + image_token_count: int +) -> Tensor: + print("loading torch decoder") + decoder = DalleBartDecoderTorch( + image_vocab_size = config['image_vocab_size'], + image_token_count = config['image_length'], + sample_token_count = image_token_count, + embed_count = config['d_model'], + attention_head_count = config['decoder_attention_heads'], + glu_embed_count = config['decoder_ffn_dim'], + layer_count = config['decoder_layers'], + batch_count = 2, + start_token = config['decoder_start_token_id'], + is_verbose = True + ) + decoder_params = convert_dalle_bart_torch_from_flax_params( + params.pop('decoder'), + layer_count=config['decoder_layers'], + is_encoder=False + ) + decoder.load_state_dict(decoder_params, strict=False) + del decoder_params + + print("sampling image tokens") + torch.manual_seed(seed) + text_tokens = torch.tensor(text_tokens).to(torch.long) + image_tokens = decoder.forward(text_tokens, encoder_state) + return image_tokens + + +def generate_image_tokens_torch( + text: str, + seed: int, + image_token_count: int, + dalle_bart_path: str +) -> numpy.ndarray: + config, vocab, merges = load_dalle_bart_metadata(dalle_bart_path) + text_tokens = tokenize(text, config, vocab, merges) + params_dalle_bart = load_dalle_bart_flax_params(dalle_bart_path) + encoder_state = encode_torch(text_tokens, config, params_dalle_bart) + image_tokens = decode_torch( + text_tokens, + encoder_state, + config, seed, params_dalle_bart, + image_token_count + ) + return image_tokens.detach().numpy() + + +if __name__ == '__main__': + args = parser.parse_args() + image_tokens = generate_image_tokens_torch( + args.text, + args.seed, + args.image_token_count, + args.dalle_bart_path + ) + if args.image_token_count < 256: + print("image tokens", list(image_tokens, )) + else: + image = detokenize_torch(image_tokens, args.vqgan_path) + image = save_image(image, args.image_path) + print(ascii_from_image(image, size=128)) + + \ No newline at end of file diff --git a/load_params.py b/load_params.py new file mode 100644 index 0000000..b5c0332 --- /dev/null +++ b/load_params.py @@ -0,0 +1,111 @@ +import os +import numpy +from copy import deepcopy +from typing import Dict +import torch +from flax import traverse_util, serialization + + +def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]: + with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f: + params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read()) + + P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.') + + for i in list(P.keys()): + j = i + if 'up' in i or 'down' in i: + j = i.replace('_', '.') + j = j.replace('proj.out', 'proj_out') + j = j.replace('nin.short', 'nin_short') + if 'bias' in i: + P[j] = P.pop(i) + elif 'scale' in i: + j = j.replace('scale', 'weight') + P[j] = P.pop(i) + elif 'kernel' in i: + j = j.replace('kernel', 'weight') + P[j] = P.pop(i).transpose(3, 2, 0, 1) + + for i in P: + P[i] = torch.tensor(P[i]) + + P['embedding.weight'] = P.pop('quantize.embedding.embedding') + + for i in list(P): + if i.split('.')[0] in ['encoder', 'quant_conv']: + P.pop(i) + + return P + + +def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]: + with open(os.path.join(path, "flax_model.msgpack"), "rb") as f: + params = serialization.msgpack_restore(f.read()) + + for codec in ['encoder', 'decoder']: + k = 'FlaxBart{}Layers'.format(codec.title()) + P: dict = params['model'][codec]['layers'][k] + P['pre_self_attn_layer_norm'] = P.pop('LayerNorm_0') + P['self_attn_layer_norm'] = P.pop('LayerNorm_1') + P['self_attn'] = P.pop('FlaxBartAttention_0') + if codec == 'decoder': + P['pre_encoder_attn_layer_norm'] = P.pop('LayerNorm_2') + P['encoder_attn_layer_norm'] = P.pop('LayerNorm_3') + P['encoder_attn'] = P.pop('FlaxBartAttention_1') + P['glu']: dict = P.pop('GLU_0') + P['glu']['ln0'] = P['glu'].pop('LayerNorm_0') + P['glu']['ln1'] = P['glu'].pop('LayerNorm_1') + P['glu']['fc0'] = P['glu'].pop('Dense_0') + P['glu']['fc1'] = P['glu'].pop('Dense_1') + P['glu']['fc2'] = P['glu'].pop('Dense_2') + + for codec in ['encoder', 'decoder']: + layers_params = params['model'][codec].pop('layers') + params['model'][codec] = { + **params['model'][codec], + **layers_params + } + + model_params = params.pop('model') + params = {**params, **model_params} + + params['decoder']['lm_head'] = params.pop('lm_head') + + return params + + +def convert_dalle_bart_torch_from_flax_params( + params: dict, + layer_count: int, + is_encoder: bool +) -> dict: + P = deepcopy(params) + P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.') + + for i in P: + P[i] = torch.tensor(P[i]) + + for i in list(P): + if 'kernel' in i: + j = i.replace('kernel', 'weight') + P[j] = P.pop(i).transpose(-1, -2) + elif 'scale' in i: + j = i.replace('scale', 'weight') + P[j] = P.pop(i) + + for i in list(P): + j = 'FlaxBart{}Layers'.format('Encoder' if is_encoder else 'Decoder') + if j in i: + for l in range(layer_count): + k = i.replace(j, 'layers.' + str(l)) + P[k] = P[i][l] + P.pop(i) + + for i in list(P): + if '_proj' in i: + P[i] = P[i][:, :, None, None] + + P['embed_tokens.weight'] = P.pop('embed_tokens.embedding') + P['embed_positions.weight'] = P.pop('embed_positions.embedding') + return P \ No newline at end of file diff --git a/models/dalle_bart_decoder_flax.py b/models/dalle_bart_decoder_flax.py new file mode 100644 index 0000000..abf4971 --- /dev/null +++ b/models/dalle_bart_decoder_flax.py @@ -0,0 +1,288 @@ +import jax, flax +from jax import lax, numpy as jnp +from flax import linen as nn +from typing import Tuple + +from .dalle_bart_encoder_flax import GLUFlax, AttentionFlax + + +class DecoderCrossAttentionFlax(AttentionFlax): + def __call__( + self, + decoder_state: jnp.ndarray, + encoder_state: jnp.ndarray, + attention_mask: jnp.ndarray, + ) -> jnp.ndarray: + keys: jnp.ndarray = self.k_proj(encoder_state) + values: jnp.ndarray = self.v_proj(encoder_state) + queries: jnp.ndarray = self.q_proj(decoder_state) + query_shape = queries.shape[:2] + (self.head_count, -1) + key_value_shape = keys.shape[:2] + (self.head_count, -1) + keys = keys.reshape(key_value_shape) + values = values.reshape(key_value_shape) + queries = queries.reshape(query_shape) + queries /= queries.shape[-1] ** 0.5 + return self.forward(keys, values, queries, attention_mask) + + +class DecoderSelfAttentionFlax(AttentionFlax): + def __call__(self, + decoder_state: jnp.ndarray, + keys_state: jnp.ndarray, + values_state: jnp.ndarray, + attention_mask: jnp.ndarray, + state_index: tuple + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + shape_split = decoder_state.shape[:2] + (self.head_count, -1) + keys_state = lax.dynamic_update_slice( + keys_state, + self.k_proj(decoder_state).reshape(shape_split), + state_index + ) + values_state = lax.dynamic_update_slice( + values_state, + self.v_proj(decoder_state).reshape(shape_split), + state_index + ) + queries = self.q_proj(decoder_state).reshape(shape_split) + queries /= queries.shape[-1] ** 0.5 + decoder_state = self.forward( + keys_state, + values_state, + queries, + attention_mask + ) + return decoder_state, (keys_state, values_state) + + +class DalleBartDecoderLayerFlax(nn.Module): + image_token_count: int + attention_head_count: int + embed_count: int + glu_embed_count: int + + def setup(self): + self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False) + self.self_attn = DecoderSelfAttentionFlax( + self.attention_head_count, + self.embed_count + ) + self.self_attn_layer_norm = nn.LayerNorm() + self.pre_encoder_attn_layer_norm = nn.LayerNorm(use_scale=False) + self.encoder_attn = DecoderCrossAttentionFlax( + self.attention_head_count, + self.embed_count, + ) + self.encoder_attn_layer_norm = nn.LayerNorm() + self.glu = GLUFlax(self.embed_count, self.glu_embed_count) + + @nn.compact + def __call__(self, + decoder_state: jnp.ndarray, + encoder_state: jnp.ndarray, + keys_state: jnp.ndarray, + values_state: jnp.ndarray, + attention_mask: jnp.ndarray, + token_index: int + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + # Self Attention + residual = decoder_state + decoder_state = self.pre_self_attn_layer_norm(decoder_state) + self_attention_mask = jnp.tile( + jnp.arange(self.image_token_count) < token_index + 1, + (decoder_state.shape[0], 1) + ) + decoder_state, keys_values_state = self.self_attn( + decoder_state, + keys_state, + values_state, + self_attention_mask, + (0, token_index, 0, 0) + ) + decoder_state = self.self_attn_layer_norm(decoder_state) + decoder_state = residual + decoder_state + + # Cross Attention + residual = decoder_state + decoder_state = self.pre_encoder_attn_layer_norm(decoder_state) + decoder_state = self.encoder_attn( + decoder_state, + encoder_state, + attention_mask + ) + decoder_state = self.encoder_attn_layer_norm(decoder_state) + decoder_state = residual + decoder_state + + # Feed forward + residual = decoder_state + decoder_state = self.glu(decoder_state) + decoder_state = residual + decoder_state + + return decoder_state, keys_values_state + + +@flax.struct.dataclass +class SampleState: + prev_token: jnp.ndarray + prng_key: jnp.ndarray + keys_state: jnp.ndarray + values_state: jnp.ndarray + +def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray: + return a * logits[0, -1] + (1 - a) * logits[1, -1] + +def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray: + top_logits, top_tokens = lax.top_k(logits, k) + suppressed = -jnp.inf * jnp.ones_like(logits) + return lax.select(logits < top_logits[-1], suppressed, logits) + +class DalleBartDecoderFlax(nn.Module): + image_token_count: int + text_token_count: int + image_vocab_count: int + attention_head_count: int + embed_count: int + glu_embed_count: int + layer_count: int + start_token: int + + def setup(self): + self.embed_tokens = nn.Embed( + self.image_vocab_count + 1, + self.embed_count + ) + self.embed_positions = nn.Embed( + self.image_token_count, + self.embed_count + ) + self.layers = nn.scan( + DalleBartDecoderLayerFlax, + variable_axes = { "params": 0, "cache": 0 }, + split_rngs = { "params": True }, + in_axes = (nn.broadcast, 0, 0, nn.broadcast, nn.broadcast), + out_axes = (0, 0), + length=self.layer_count, + )( + self.image_token_count, + self.attention_head_count, + self.embed_count, + self.glu_embed_count, + name="FlaxBartDecoderLayers" + ) + self.layernorm_embedding = nn.LayerNorm() + self.final_ln = nn.LayerNorm(use_scale=False) + self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False) + + def __call__(self, + encoder_state: jnp.ndarray, + keys_state: jnp.ndarray, + values_state: jnp.ndarray, + attention_mask: jnp.ndarray, + prev_token: int, + token_index: int + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + batch_count = encoder_state.shape[0] + ones = jnp.ones((batch_count, 1), dtype=jnp.int32) + decoder_state = self.embed_tokens(prev_token * ones) + decoder_state += self.embed_positions(token_index * ones) + decoder_state = self.layernorm_embedding(decoder_state) + decoder_state, (keys_state, values_state) = self.layers( + decoder_state, + encoder_state, + keys_state, + values_state, + attention_mask, + token_index + ) + decoder_state = self.final_ln(decoder_state) + decoder_state = self.lm_head(decoder_state) + return decoder_state, keys_state, values_state + + def compute_logits(self, + text_tokens: jnp.ndarray, + encoder_state: jnp.ndarray, + params: dict + ) -> jnp.ndarray: + batch_count = encoder_state.shape[0] + state_shape = ( + self.layer_count, + batch_count, + self.image_token_count, + self.attention_head_count, + self.embed_count // self.attention_head_count + ) + keys_state = jnp.zeros(state_shape) + values_state = jnp.zeros(state_shape) + + logits, _, _ = self.apply( + { 'params': params }, + encoder_state = encoder_state, + keys_state = keys_state, + values_state = values_state, + attention_mask = jnp.not_equal(text_tokens, 1), + prev_token = self.start_token, + token_index = 0 + ) + + return super_conditioned(logits, 10.0) + + def sample_image_tokens(self, + text_tokens: jnp.ndarray, + encoder_state: jnp.ndarray, + prng_key: jax.random.PRNGKey, + params: dict + ) -> jnp.ndarray: + attention_mask = jnp.not_equal(text_tokens, 1) + + def sample_next_image_token( + state: SampleState, + token_index: int + ) -> Tuple[SampleState, None]: + logits, keys_state, values_state = self.apply( + { 'params': params }, + encoder_state = encoder_state, + keys_state = state.keys_state, + values_state = state.values_state, + attention_mask = attention_mask, + prev_token = state.prev_token, + token_index = token_index + ) + + logits = super_conditioned(logits, 10.0) + logits = keep_top_k(logits, k=50) + + prng_key, prng_key_next = jax.random.split(state.prng_key) + next_token = jax.random.categorical(prng_key, logits, axis=-1) + + state = SampleState( + prev_token = next_token, + prng_key = prng_key_next, + keys_state = keys_state, + values_state = values_state + ) + + return state, next_token + + batch_count = encoder_state.shape[0] + state_shape = ( + self.layer_count, + batch_count, + self.image_token_count, + self.attention_head_count, + self.embed_count // self.attention_head_count + ) + + initial_state = SampleState( + prev_token = self.start_token, + prng_key = prng_key, + keys_state = jnp.zeros(state_shape), + values_state = jnp.zeros(state_shape) + ) + + _, image_tokens = lax.scan( + sample_next_image_token, + initial_state, + jnp.arange(self.image_token_count) + ) + + return image_tokens \ No newline at end of file diff --git a/models/dalle_bart_decoder_torch.py b/models/dalle_bart_decoder_torch.py new file mode 100644 index 0000000..e2a79e5 --- /dev/null +++ b/models/dalle_bart_decoder_torch.py @@ -0,0 +1,214 @@ +import torch +from torch import LongTensor, nn, FloatTensor, BoolTensor +from typing import List, Tuple + +from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch + + +class DecoderCrossAttentionTorch(AttentionTorch): + def forward( + self, + decoder_state: FloatTensor, + encoder_state: FloatTensor, + attention_mask: BoolTensor + ) -> FloatTensor: + keys = self.k_proj.forward(encoder_state) + values = self.v_proj.forward(encoder_state) + queries = self.q_proj.forward(decoder_state) + return super().forward(keys, values, queries, attention_mask) + + +class DecoderSelfAttentionTorch(AttentionTorch): + def forward(self, + decoder_state: FloatTensor, + keys_values: FloatTensor, + attention_mask: BoolTensor, + token_index: LongTensor + ) -> Tuple[FloatTensor, FloatTensor, FloatTensor]: + keys = self.k_proj.forward(decoder_state) + values = self.v_proj.forward(decoder_state) + queries = self.q_proj.forward(decoder_state) + + batch_count = decoder_state.shape[0] + token_count = keys_values.shape[-1] + keys_values = torch.where( + (torch.arange(token_count) == token_index)[None, None, :], + torch.cat([keys, values]).squeeze(2), + keys_values + ) + keys, values = keys_values[:batch_count, :, None], keys_values[batch_count:, :, None] + + decoder_state = super().forward(keys, values, queries, attention_mask) + return decoder_state, keys_values + + +class DecoderLayerTorch(nn.Module): + def __init__(self, + image_token_count: int, + head_count: int, + embed_count: int, + glu_embed_count: int + ): + super().__init__() + self.image_token_count = image_token_count + self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) + self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count) + self.self_attn_layer_norm = nn.LayerNorm(embed_count) + self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) + self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count) + self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) + self.glu = GLUTorch(embed_count, glu_embed_count) + + def forward(self, + decoder_state: FloatTensor, + encoder_state: FloatTensor, + keys_values_state: FloatTensor, + attention_mask: BoolTensor, + token_index: LongTensor + ) -> Tuple[FloatTensor, FloatTensor]: + # Self Attention + residual = decoder_state + decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) + self_attn_mask = torch.arange(self.image_token_count) < token_index + 1 + self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0]) + decoder_state = decoder_state.transpose(1, 2).unsqueeze(2) + # print("decoder_state", decoder_state.shape) + decoder_state, keys_values_state = self.self_attn.forward( + decoder_state, + keys_values_state, + self_attn_mask, + token_index + ) + decoder_state = decoder_state.transpose(1, 3).squeeze(2) + decoder_state = self.self_attn_layer_norm.forward(decoder_state) + decoder_state = residual + decoder_state + + # Cross Attention + residual = decoder_state + decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state) + decoder_state = decoder_state.transpose(1, 2).unsqueeze(2) + decoder_state = self.encoder_attn.forward( + decoder_state, + encoder_state, + attention_mask + ) + decoder_state = decoder_state.transpose(1, 3).squeeze(2) + decoder_state = self.encoder_attn_layer_norm.forward(decoder_state) + decoder_state = residual + decoder_state + + # Feed forward + residual = decoder_state + decoder_state = self.glu.forward(decoder_state) + decoder_state = residual + decoder_state + + return decoder_state, keys_values_state + + +class DalleBartDecoderTorch(nn.Module): + def __init__(self, + image_vocab_size: int, + image_token_count: int, + sample_token_count: int, + embed_count: int, + attention_head_count: int, + glu_embed_count: int, + layer_count: int, + batch_count: int, + start_token: int, + is_verbose: bool + ): + super().__init__() + self.is_verbose = is_verbose + self.layer_count = layer_count + self.sample_token_count = sample_token_count + self.start_token = torch.tensor([start_token]).to(torch.long) + self.pad_token = torch.tensor([1]).to(torch.long) + self.condition_factor = torch.tensor([10]).to(torch.float) + self.image_token_count = image_token_count + self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count) + self.embed_positions = nn.Embedding(image_token_count, embed_count) + self.layers: List[DecoderLayerTorch] = nn.ModuleList([ + DecoderLayerTorch( + image_token_count, + attention_head_count, + embed_count, + glu_embed_count + ) + for _ in range(layer_count) + ]) + self.layernorm_embedding = nn.LayerNorm(embed_count) + self.final_ln = nn.LayerNorm(embed_count) + self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False) + self.keys_values_state_shape = ( + layer_count * 2 * batch_count, + embed_count, + image_token_count + ) + + + def decode_step(self, + text_tokens: LongTensor, + encoder_state: FloatTensor, + keys_values_state: FloatTensor, + prev_token_and_index: LongTensor + ) -> Tuple[LongTensor, FloatTensor]: + attention_mask = text_tokens.not_equal(self.pad_token) + batch_count = encoder_state.shape[0] + prev_token = torch.cat([prev_token_and_index[:1]] * batch_count) + token_index = torch.cat([prev_token_and_index[1:]] * batch_count) + decoder_state = self.embed_tokens.forward(prev_token) + decoder_state += self.embed_positions.forward(token_index) + decoder_state = self.layernorm_embedding.forward(decoder_state) + decoder_state = decoder_state[:, None] # (batch_count, 1, embed_count) + keys_values = [] + for i, layer in enumerate(self.layers): + j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count + decoder_state, keys_values_layer = layer.forward( + decoder_state, + encoder_state, + keys_values_state[j1:j2], + attention_mask, + token_index[:1] + ) + keys_values.append(keys_values_layer) + keys_values = torch.cat(keys_values, dim=0) + decoder_state = self.final_ln(decoder_state) # (batch_count, 1, embed_count) + logits = self.lm_head(decoder_state) # (batch_count, 1, vocab_size) + a = self.condition_factor + logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1] + + top_logits = logits.sort(descending=True)[0][:50] + probs = torch.where( + logits < top_logits[-1], + torch.zeros([1]), + torch.exp(logits - top_logits[0]) + ) + return probs, keys_values + + + def forward(self, + text_tokens: LongTensor, + encoder_state: FloatTensor + ) -> LongTensor: + image_tokens: List[LongTensor] = [] + keys_values_state = torch.zeros(self.keys_values_state_shape) + image_token = self.start_token + encoder_state = encoder_state.transpose(1, 2).unsqueeze(2) + + for i in range(self.sample_token_count): + token_index = torch.tensor([i]).to(torch.long) + probs, keys_values_state = self.decode_step( + text_tokens = text_tokens, + encoder_state = encoder_state, + keys_values_state = keys_values_state, + prev_token_and_index = torch.cat([image_token, token_index]) + ) + + image_token = torch.multinomial(probs, 1) + image_tokens += [image_token] + + if self.is_verbose: + token = int(image_token.detach().numpy()) + print("image token {} is {}".format(i, token)) + + return torch.cat(image_tokens) \ No newline at end of file diff --git a/models/dalle_bart_encoder_flax.py b/models/dalle_bart_encoder_flax.py new file mode 100644 index 0000000..8eff67a --- /dev/null +++ b/models/dalle_bart_encoder_flax.py @@ -0,0 +1,143 @@ +from functools import partial +import jax +from jax import lax, numpy as jnp +from flax import linen as nn + + +class GLUFlax(nn.Module): + count_in_out: int + count_middle: int + + def setup(self): + self.gelu = partial(nn.gelu, approximate=False) + self.ln0 = nn.LayerNorm(use_scale=False) + self.ln1 = nn.LayerNorm(use_scale=False) + self.fc0 = nn.Dense(self.count_middle, use_bias=False) + self.fc1 = nn.Dense(self.count_middle, use_bias=False) + self.fc2 = nn.Dense(self.count_in_out, use_bias=False) + + @nn.compact + def __call__(self, z: jnp.ndarray) -> jnp.ndarray: + z = self.ln0(z) + z = self.ln1(self.gelu(self.fc0(z)) * self.fc1(z)) + z = self.fc2(z) + return z + +class AttentionFlax(nn.Module): + head_count: int + embed_count: int + + def setup(self): + self.q_proj = nn.Dense(self.embed_count, use_bias=False) + self.k_proj = nn.Dense(self.embed_count, use_bias=False) + self.v_proj = nn.Dense(self.embed_count, use_bias=False) + self.out_proj = nn.Dense(self.embed_count, use_bias=False) + + def forward(self, + keys: jnp.ndarray, + values: jnp.ndarray, + queries: jnp.ndarray, + attention_mask: jnp.ndarray + ) -> jnp.ndarray: + attention_bias: jnp.ndarray = lax.select( + attention_mask, + jnp.full(attention_mask.shape, 0.0), + jnp.full(attention_mask.shape, -jnp.inf), + ) + attention_weights: jnp.ndarray = jnp.einsum( + 'bqhd,bkhd->bhqk', + queries, + keys + ) + attention_weights += attention_bias[:, None, None, :] + attention_weights = jax.nn.softmax(attention_weights) + attention_output: jnp.ndarray = jnp.einsum( + "bhqk,bkhd->bqhd", + attention_weights, + values + ) + shape = attention_output.shape[:2] + (self.embed_count,) + attention_output = attention_output.reshape(shape) + attention_output = self.out_proj(attention_output) + return attention_output + +class EncoderSelfAttentionFlax(AttentionFlax): + def __call__( + self, + encoder_state: jnp.ndarray, + attention_mask: jnp.ndarray + ) -> jnp.ndarray: + shape_split = encoder_state.shape[:2] + (self.head_count, -1) + keys = self.k_proj(encoder_state).reshape(shape_split) + values = self.v_proj(encoder_state).reshape(shape_split) + queries = self.q_proj(encoder_state).reshape(shape_split) + queries /= queries.shape[-1] ** 0.5 + return self.forward(keys, values, queries, attention_mask) + +class DalleBartEncoderLayerFlax(nn.Module): + attention_head_count: int + embed_count: int + glu_embed_count: int + + def setup(self): + self.pre_self_attn_layer_norm = nn.LayerNorm(use_scale=False) + self.self_attn = EncoderSelfAttentionFlax( + self.attention_head_count, + self.embed_count + ) + self.self_attn_layer_norm = nn.LayerNorm() + self.glu = GLUFlax(self.embed_count, self.glu_embed_count) + + @nn.compact + def __call__(self, + encoder_state: jnp.ndarray, + attention_mask: jnp.ndarray + ) -> jnp.ndarray: + residual = encoder_state + encoder_state = self.pre_self_attn_layer_norm(encoder_state) + encoder_state = self.self_attn(encoder_state, attention_mask) + encoder_state = self.self_attn_layer_norm(encoder_state) + encoder_state = residual + encoder_state + residual = encoder_state + encoder_state = self.glu(encoder_state) + encoder_state = residual + encoder_state + return encoder_state, None + +class DalleBartEncoderFlax(nn.Module): + attention_head_count: int + embed_count: int + glu_embed_count: int + text_token_count: int + text_vocab_count: int + layer_count: int + + def setup(self): + self.embed_tokens = nn.Embed(self.text_vocab_count, self.embed_count) + self.embed_positions = nn.Embed(self.text_token_count, self.embed_count) + self.layers = nn.scan( + DalleBartEncoderLayerFlax, + variable_axes = { "params": 0, "cache": 0 }, + split_rngs = { "params": True }, + in_axes = nn.broadcast, + length = self.layer_count + )( + self.attention_head_count, + self.embed_count, + self.glu_embed_count, + name="FlaxBartEncoderLayers" + ) + self.layernorm_embedding = nn.LayerNorm() + self.final_ln = nn.LayerNorm(use_scale=False) + + def __call__(self, text_tokens: jnp.ndarray) -> jnp.ndarray: + batch_count, token_count = text_tokens.shape + pose_tokens = jnp.tile(jnp.arange(token_count), (batch_count, 1)) + attention_mask = jnp.not_equal(text_tokens, 1) + encoder_state = ( + self.embed_tokens(text_tokens) + + self.embed_positions(pose_tokens) + ) + encoder_state = self.layernorm_embedding(encoder_state) + encoder_state, _ = self.layers(encoder_state, attention_mask) + encoder_state = self.final_ln(encoder_state) + return encoder_state \ No newline at end of file diff --git a/models/dalle_bart_encoder_torch.py b/models/dalle_bart_encoder_torch.py new file mode 100644 index 0000000..9e481f3 --- /dev/null +++ b/models/dalle_bart_encoder_torch.py @@ -0,0 +1,159 @@ +from typing import List +import torch +from torch import nn, BoolTensor, FloatTensor, LongTensor + +class GLUTorch(nn.Module): + def __init__(self, count_in_out, count_middle): + super().__init__() + self.gelu = nn.GELU() + self.ln0 = nn.LayerNorm(count_in_out) + self.ln1 = nn.LayerNorm(count_middle) + self.fc0 = nn.Linear(count_in_out, count_middle, bias=False) + self.fc1 = nn.Linear(count_in_out, count_middle, bias=False) + self.fc2 = nn.Linear(count_middle, count_in_out, bias=False) + + def forward(self, z: FloatTensor) -> FloatTensor: + z = self.ln0.forward(z) + w = self.fc0.forward(z) + w = self.gelu.forward(w) + v = self.fc1.forward(z) + z = self.ln1.forward(w * v) + z = self.fc2.forward(z) + return z + +class AttentionTorch(nn.Module): + def __init__(self, head_count: int, embed_count: int): + super().__init__() + self.head_count = head_count + self.embed_count = embed_count + self.head_dim = embed_count // head_count + + self.k_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) + self.v_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) + self.q_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) + self.out_proj = nn.Conv2d(embed_count, embed_count, 1, bias=False) + + def forward(self, + keys: FloatTensor, + values: FloatTensor, + queries: FloatTensor, + attention_mask: BoolTensor + ) -> FloatTensor: + batch_count = keys.shape[0] + + # b(hc)1q -> bqhc + # print(keys.shape, "keys", values.shape, "values", queries.shape, "queries") + keys = keys.transpose(1, 3) + keys = keys.reshape(keys.shape[:2] + (self.head_count, -1)) + + # b(hc)1q -> bchq + shape = (batch_count, self.head_count, self.head_dim, -1) + values = values.reshape(shape) + values = values.transpose(1, 2) + queries = queries.reshape(shape) + queries = queries.transpose(1, 2) + + # print(keys.shape, "keys", values.shape, "values", queries.shape, "queries") + + attention_bias = torch.where( + attention_mask, + torch.zeros([1, 1]), + torch.ones([1, 1]) * (-torch.inf), + ) + attention_weights: FloatTensor = torch.einsum( + 'bchq,bkhc->bkhq', + queries / self.head_dim ** 0.5, + keys + ) + attention_weights += attention_bias[:, :, None, None] + attention_weights = torch.softmax(attention_weights, 1) + # print(attention_weights.shape, "attention_weights") + hidden_state: FloatTensor = torch.einsum( + "bkhq,bchk->bchq", + attention_weights, + values + ) + # bchq -> b(hc)1q + # print(hidden_state.shape, "hidden_state") + hidden_state = hidden_state.transpose(1, 2) + hidden_state = hidden_state.reshape(batch_count, self.embed_count, 1, -1) + hidden_state = self.out_proj.forward(hidden_state) + # print(hidden_state.shape, "hidden_state") + return hidden_state + + +class EncoderSelfAttentionTorch(AttentionTorch): + def forward( + self, + encoder_state: FloatTensor, + attention_mask: BoolTensor + ) -> FloatTensor: + encoder_state = encoder_state.transpose(1, 2).unsqueeze(2) + # print(encoder_state.shape, "encoder_state") + keys = self.k_proj.forward(encoder_state) + values = self.v_proj.forward(encoder_state) + queries = self.q_proj.forward(encoder_state) + return super().forward(keys, values, queries, attention_mask) + + +class EncoderLayerTorch(nn.Module): + def __init__(self, embed_count: int, head_count: int, glu_embed_count: int): + super().__init__() + self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) + self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count) + self.self_attn_layer_norm = nn.LayerNorm(embed_count) + self.glu = GLUTorch(embed_count, glu_embed_count) + + def forward( + self, + encoder_state: FloatTensor, + attention_mask: BoolTensor + ) -> FloatTensor: + residual = encoder_state + encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state) + encoder_state = self.self_attn.forward(encoder_state, attention_mask) + encoder_state = encoder_state.transpose(1, 3).squeeze(2) + encoder_state = self.self_attn_layer_norm.forward(encoder_state) + encoder_state = residual + encoder_state + residual = encoder_state + encoder_state = self.glu.forward(encoder_state) + encoder_state = residual + encoder_state + return encoder_state + + +class DalleBartEncoderTorch(nn.Module): + def __init__(self, + layer_count: int, + embed_count: int, + attention_head_count: int, + text_vocab_count: int, + text_token_count: int, + glu_embed_count: int + ): + super().__init__() + self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) + self.embed_positions = nn.Embedding(text_token_count, embed_count) + self.layers: List[EncoderLayerTorch] = nn.ModuleList([ + EncoderLayerTorch( + embed_count = embed_count, + head_count = attention_head_count, + glu_embed_count = glu_embed_count + ) + for _ in range(layer_count) + ]) + self.layernorm_embedding = nn.LayerNorm(embed_count) + self.final_ln = nn.LayerNorm(embed_count) + + def forward(self, text_tokens: LongTensor) -> FloatTensor: + attention_mask = text_tokens.not_equal(1) + batch_count, token_count = text_tokens.shape + pose_tokens = torch.stack([torch.arange(token_count)] * batch_count) + encoder_state = ( + self.embed_tokens.forward(text_tokens) + + self.embed_positions.forward(pose_tokens) + ) + encoder_state = self.layernorm_embedding.forward(encoder_state) + for layer in self.layers: + encoder_state = layer.forward(encoder_state, attention_mask) + encoder_state = self.final_ln.forward(encoder_state) + return encoder_state \ No newline at end of file diff --git a/models/vqgan_detokenizer.py b/models/vqgan_detokenizer.py new file mode 100644 index 0000000..a56a1cb --- /dev/null +++ b/models/vqgan_detokenizer.py @@ -0,0 +1,175 @@ +import torch +from torch import Tensor +from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding + +batch_size: int = 1 + +class ResnetBlock(Module): + def __init__(self, log2_count_in: int, log2_count_out: int): + super().__init__() + m, n = 2 ** log2_count_in, 2 ** log2_count_out + self.is_middle = m == n + self.norm1 = GroupNorm(2 ** 5, m) + self.conv1 = Conv2d(m, n, 3, padding=1) + self.norm2 = GroupNorm(2 ** 5, n) + self.conv2 = Conv2d(n, n, 3, padding=1) + if not self.is_middle: + self.nin_shortcut = Conv2d(m, n, 1) + + def forward(self, x: Tensor) -> Tensor: + h = x + h = self.norm1.forward(h) + h *= torch.sigmoid(h) + h = self.conv1.forward(h) + h = self.norm2.forward(h) + h *= torch.sigmoid(h) + h = self.conv2(h) + if not self.is_middle: + x = self.nin_shortcut.forward(x) + return x + h + + +class AttentionBlock(Module): + def __init__(self): + super().__init__() + n = 2 ** 9 + self.norm = GroupNorm(2 ** 5, n) + self.q = Conv2d(n, n, 1) + self.k = Conv2d(n, n, 1) + self.v = Conv2d(n, n, 1) + self.proj_out = Conv2d(n, n, 1) + + def forward(self, x: Tensor) -> Tensor: + n = 2 ** 9 + h = x + h = self.norm(h) + q = self.q.forward(h) + k = self.k.forward(h) + v = self.v.forward(h) + q = q.reshape(batch_size, n, 2 ** 8) + q = q.permute(0, 2, 1) + k = k.reshape(batch_size, n, 2 ** 8) + w = torch.bmm(q, k) + w /= n ** 0.5 + w = torch.softmax(w, dim=2) + v = v.reshape(batch_size, n, 2 ** 8) + w = w.permute(0, 2, 1) + h = torch.bmm(v, w) + h = h.reshape(batch_size, n, 2 ** 4, 2 ** 4) + h = self.proj_out.forward(h) + return x + h + +class MiddleLayer(Module): + def __init__(self): + super().__init__() + self.block_1 = ResnetBlock(9, 9) + self.attn_1 = AttentionBlock() + self.block_2 = ResnetBlock(9, 9) + + def forward(self, h: Tensor) -> Tensor: + h = self.block_1.forward(h) + h = self.attn_1.forward(h) + h = self.block_2.forward(h) + return h + +class Upsample(Module): + def __init__(self, log2_count): + super().__init__() + n = 2 ** log2_count + self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) + self.conv = Conv2d(n, n, 3, padding=1) + + def forward(self, x: Tensor) -> Tensor: + x = self.upsample.forward(x) + x = self.conv.forward(x) + return x + +class UpsampleBlock(Module): + def __init__( + self, + log2_count_in: int, + log2_count_out: int, + has_attention: bool, + has_upsample: bool + ): + super().__init__() + self.has_attention = has_attention + self.has_upsample = has_upsample + self.block = ModuleList([ + ResnetBlock(log2_count_in, log2_count_out), + ResnetBlock(log2_count_out, log2_count_out), + ResnetBlock(log2_count_out, log2_count_out) + ]) + if has_attention: + self.attn = ModuleList([ + AttentionBlock(), + AttentionBlock(), + AttentionBlock() + ]) + else: + self.attn = ModuleList() + + if has_upsample: + self.upsample = Upsample(log2_count_out) + + + def forward(self, h: Tensor) -> Tensor: + for j in range(3): + h = self.block[j].forward(h) + if self.has_attention: + h = self.attn[j].forward(h) + if self.has_upsample: + h = self.upsample.forward(h) + return h + +class Decoder(Module): + def __init__(self): + super().__init__() + + self.conv_in = Conv2d(2 ** 8, 2 ** 9, 3, padding=1) + self.mid = MiddleLayer() + + self.up = ModuleList([ + UpsampleBlock(7, 7, False, False), + UpsampleBlock(8, 7, False, True), + UpsampleBlock(8, 8, False, True), + UpsampleBlock(9, 8, False, True), + UpsampleBlock(9, 9, True, True) + ]) + + self.norm_out = GroupNorm(2 ** 5, 2 ** 7) + self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1) + + def forward(self, z: Tensor) -> Tensor: + z = self.conv_in.forward(z) + z = self.mid.forward(z) + + for i in reversed(range(5)): + z = self.up[i].forward(z) + + z = self.norm_out.forward(z) + z *= torch.sigmoid(z) + z = self.conv_out.forward(z) + return z + +class VQGanDetokenizer(Module): + def __init__(self): + super().__init__() + m, n = 2 ** 14, 2 ** 8 + self.embedding = Embedding(m, n) + self.post_quant_conv = Conv2d(n, n, 1) + self.decoder = Decoder() + + def forward(self, z: Tensor) -> Tensor: + z = self.embedding.forward(z) + z = z.view((batch_size, 2 ** 4, 2 ** 4, 2 ** 8)) + z = z.permute(0, 3, 1, 2).contiguous() + z = self.post_quant_conv.forward(z) + z = self.decoder.forward(z) + z = z.permute(0, 2, 3, 1) + # z = torch.concat(( + # torch.concat((z[0], z[1]), axis=1), + # torch.concat((z[2], z[3]), axis=1) + # ), axis=0) + z = z.clip(0.0, 1.0) * 255 + return z[0] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..165595e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +torch +flax \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000..ee82889 --- /dev/null +++ b/setup.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +pip install -r requirements.txt + +mkdir -p pretrained + +# download vqgan +git lfs install +git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan + +# download dalle-mini and dalle mega +pip install wandb +wandb login +wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0 +wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14 \ No newline at end of file diff --git a/text_tokenizer.py b/text_tokenizer.py new file mode 100644 index 0000000..c9afd70 --- /dev/null +++ b/text_tokenizer.py @@ -0,0 +1,39 @@ +from math import inf +from typing import List, Tuple + +class TextTokenizer: + def __init__(self, vocab: dict, merges: List[str]): + self.token_from_subword = vocab + pairs = [tuple(pair.split()) for pair in merges] + self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) + + def __call__(self, text: str) -> List[int]: + sep_token = self.token_from_subword[''] + cls_token = self.token_from_subword[''] + unk_token = self.token_from_subword[''] + text = text.lower().encode("ascii", errors="ignore").decode() + tokens = [ + self.token_from_subword.get(subword, unk_token) + for word in text.split(" ") if len(word) > 0 + for subword in self.get_byte_pair_encoding(word) + ] + return [cls_token] + tokens + [sep_token] + + def get_byte_pair_encoding(self, word: str) -> List[str]: + def get_pair_rank(pair: Tuple[str, str]) -> int: + return self.rank_from_pair.get(pair, inf) + + subwords = [chr(ord(" ") + 256)] + list(word) + while len(subwords) > 1: + pairs = list(zip(subwords[:-1], subwords[1:])) + pair_to_merge = min(pairs, key=get_pair_rank) + if pair_to_merge not in self.rank_from_pair: break + i = pairs.index(pair_to_merge) + subwords = ( + (subwords[:i] if i > 0 else []) + + [subwords[i] + subwords[i + 1]] + + (subwords[i + 2:] if i + 2 < len(subwords) else []) + ) + + print(subwords) + return subwords \ No newline at end of file