diff --git a/image_from_text.py b/image_from_text.py
index 9d151e2..9008253 100644
--- a/image_from_text.py
+++ b/image_from_text.py
@@ -11,9 +11,10 @@ parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
 parser.add_argument('--text', type=str, default='alien life')
 parser.add_argument('--seed', type=int, default=-1)
-parser.add_argument('--image_path', type=str, default='generated')
-parser.add_argument('--models_root', type=str, default='pretrained')
-parser.add_argument('--token_count', type=int, default=256) # for debugging
+parser.add_argument('--grid-size', type=int, default=1)
+parser.add_argument('--image-path', type=str, default='generated')
+parser.add_argument('--models-root', type=str, default='pretrained')
+parser.add_argument('--token-count', type=int, default=256) # for debugging
 
 
 def ascii_from_image(image: Image.Image, size: int) -> str:
@@ -38,6 +39,7 @@ def generate_image(
     is_mega: bool,
     text: str,
     seed: int,
+    grid_size: int,
     image_path: str,
     models_root: str,
     token_count: int
@@ -51,10 +53,10 @@ def generate_image(
     )
 
     if token_count < 256:
-        image_tokens = model.generate_image_tokens(text, seed)
-        print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
+        image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
+        print('image tokens', image_tokens.to('cpu').detach().numpy())
     else:
-        image = model.generate_image(text, seed)
+        image = model.generate_image(text, seed, grid_size)
         save_image(image, image_path)
         print(ascii_from_image(image, size=128))
 
@@ -66,6 +68,7 @@ if __name__ == '__main__':
         is_mega=args.mega,
         text=args.text,
         seed=args.seed,
+        grid_size=args.grid_size,
         image_path=args.image_path,
         models_root=args.models_root,
         token_count=args.token_count
diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py
index 37d0c6b..f7bb6ef 100644
--- a/min_dalle/min_dalle.py
+++ b/min_dalle/min_dalle.py
@@ -28,7 +28,6 @@ class MinDalle:
         self.is_reusable = is_reusable
         self.is_verbose = is_verbose
         self.sample_token_count = sample_token_count
-        self.batch_count = 2
         self.text_token_count = 64
         self.image_token_count = 256
         self.layer_count = 24 if is_mega else 12
@@ -128,8 +127,7 @@
             embed_count = self.embed_count,
             glu_embed_count = self.glu_embed_count,
             layer_count = self.layer_count,
-            start_token = self.image_vocab_count,
-            batch_count = self.batch_count
+            start_token = self.image_vocab_count
         )
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
@@ -148,7 +146,12 @@
         if torch.cuda.is_available():
             self.detokenizer = self.detokenizer.cuda()
 
-    def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
+    def generate_image_tokens(
+        self,
+        text: str,
+        seed: int,
+        image_count: int
+    ) -> LongTensor:
         if self.is_verbose: print("tokenizing text")
         tokens = self.tokenizer.tokenize(text)
         if self.is_verbose: print("text tokens", tokens)
@@ -166,18 +169,29 @@
         if not self.is_reusable: self.init_decoder()
         if self.is_verbose: print("sampling image tokens")
-        if seed < 0: seed = random.randint(0, 2 ** 31)
-        torch.manual_seed(seed)
-        image_tokens = self.decoder.forward(text_tokens, encoder_state)
+        if seed > 0: torch.manual_seed(seed)
+        image_tokens = self.decoder.forward(
+            image_count,
+            text_tokens,
+            encoder_state
+        )
         if not self.is_reusable: del self.decoder
         return image_tokens
 
-    def generate_image(self, text: str, seed: int) -> Image.Image:
-        image_tokens = self.generate_image_tokens(text, seed)
+    def generate_image(
+        self,
+        text: str,
+        seed: int = -1,
+        grid_size: int = 1
+    ) -> Image.Image:
+        image_count = grid_size ** 2
+        image_tokens = self.generate_image_tokens(text, seed, image_count)
         if not self.is_reusable: self.init_detokenizer()
         if self.is_verbose: print("detokenizing image")
-        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
         if not self.is_reusable: del self.detokenizer
+        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
+        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
         image = Image.fromarray(image.to('cpu').detach().numpy())
         return image
\ No newline at end of file
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index 4b71858..ffbf51c 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -1,4 +1,3 @@
-from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
@@ -26,7 +25,7 @@
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
-    ) -> Tuple[FloatTensor, FloatTensor]:
+    ) -> tuple[FloatTensor, FloatTensor]:
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
@@ -71,13 +70,13 @@
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
-    ) -> Tuple[FloatTensor, FloatTensor]:
+    ) -> tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
         self_attn_mask = self.token_indices < token_index + 1
+        self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
         token_mask = self.token_indices == token_index
-        self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
         decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
             attention_state,
@@ -116,17 +115,17 @@
         attention_head_count: int,
         glu_embed_count: int,
         layer_count: int,
-        batch_count: int,
         start_token: int
     ):
         super().__init__()
         self.layer_count = layer_count
+        self.embed_count = embed_count
         self.sample_token_count = sample_token_count
         self.condition_factor = 10.0
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayer] = nn.ModuleList([
+        self.layers: list[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
                 image_token_count,
                 attention_head_count,
@@ -138,12 +137,6 @@
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
-        self.attention_state_shape = (
-            layer_count,
-            2 * batch_count,
-            image_token_count,
-            embed_count
-        )
         self.zero_prob = torch.zeros([1])
         self.token_indices = torch.arange(self.sample_token_count)
         self.start_token = torch.tensor([start_token]).to(torch.long)
@@ -155,17 +148,16 @@
 
     def decode_step(
         self,
-        text_tokens: LongTensor,
+        attention_mask: BoolTensor,
         encoder_state: FloatTensor,
         attention_state: FloatTensor,
-        prev_token: LongTensor,
+        prev_tokens: LongTensor,
         token_index: LongTensor
-    ) -> Tuple[LongTensor, FloatTensor]:
-        attention_mask = text_tokens.not_equal(1)
-        batch_count = encoder_state.shape[0]
-        prev_token_batched = torch.cat([prev_token] * batch_count)
-        token_index_batched = torch.cat([token_index] * batch_count)
-        decoder_state = self.embed_tokens.forward(prev_token_batched)
+    ) -> tuple[LongTensor, FloatTensor]:
+        image_count = encoder_state.shape[0] // 2
+        token_index_batched = token_index[[0] * image_count * 2]
+        prev_tokens = prev_tokens[list(range(image_count)) * 2]
+        decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
@@ -182,38 +174,52 @@
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
-        logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
+        logits: FloatTensor = (
+            logits[:image_count, -1] * (1 - a) +
+            logits[image_count:, -1] * a
+        )
 
         top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
-            logits < top_logits[-1],
+            logits < top_logits[:, [-1]],
             self.zero_prob,
-            torch.exp(logits - top_logits[0])
+            torch.exp(logits - top_logits[:, [0]])
         )
         return probs, torch.stack(attention_states_new)
 
 
     def forward(
         self,
+        image_count: int,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
     ) -> LongTensor:
-        image_tokens: List[LongTensor] = []
-        attention_state = torch.zeros(self.attention_state_shape)
-        if torch.cuda.is_available():
-            attention_state = attention_state.cuda()
-        image_token = self.start_token
+        expanded_indices = [0] * image_count + [1] * image_count
+        text_tokens = text_tokens[expanded_indices]
+        encoder_state = encoder_state[expanded_indices]
+        attention_mask = text_tokens.not_equal(1)
+        attention_state_shape = (
+            self.layer_count,
+            image_count * 4,
+            self.image_token_count,
+            self.embed_count
+        )
+        attention_state = torch.zeros(attention_state_shape)
+        if torch.cuda.is_available(): attention_state = attention_state.cuda()
+
+        image_tokens = self.start_token[[0] * image_count]
+        image_tokens_sequence: list[LongTensor] = []
         for i in range(self.sample_token_count):
             probs, attention_state = self.decode_step(
-                text_tokens = text_tokens,
+                attention_mask = attention_mask,
                 encoder_state = encoder_state,
                 attention_state = attention_state,
-                prev_token = image_token,
+                prev_tokens = image_tokens,
                 token_index = self.token_indices[[i]]
             )
-            image_token = torch.multinomial(probs, 1)
-            image_tokens += [image_token]
-
-        return torch.cat(image_tokens)
\ No newline at end of file
+            image_tokens = torch.multinomial(probs, 1)[:, 0]
+            image_tokens_sequence += [image_tokens]
+
+        return torch.stack(image_tokens_sequence).T
\ No newline at end of file
diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py
index e3d8eb8..7d18a65 100644
--- a/min_dalle/models/dalle_bart_encoder.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -1,4 +1,3 @@
-from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
@@ -121,7 +120,7 @@
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayer] = nn.ModuleList([
+        self.layers: list[EncoderLayer] = nn.ModuleList([
             EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
@@ -137,8 +136,7 @@
 
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
-        batch_count = text_tokens.shape[0]
-        pose_tokens = torch.stack([self.token_indices] * batch_count)
+        pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py
index 1233046..43ebc34 100644
--- a/min_dalle/models/vqgan_detokenizer.py
+++ b/min_dalle/models/vqgan_detokenizer.py
@@ -3,8 +3,6 @@ from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 torch.set_grad_enabled(False)
 
-BATCH_COUNT: int = 1
-
 
 class ResnetBlock(Module):
     def __init__(self, log2_count_in: int, log2_count_out: int):
@@ -42,22 +40,22 @@
         self.proj_out = Conv2d(n, n, 1)
 
     def forward(self, x: Tensor) -> Tensor:
-        n = 2 ** 9
+        n, m = 2 ** 9, x.shape[0]
         h = x
         h = self.norm(h)
         q = self.q.forward(h)
         k = self.k.forward(h)
         v = self.v.forward(h)
-        q = q.reshape(BATCH_COUNT, n, 2 ** 8)
+        q = q.reshape(m, n, 2 ** 8)
         q = q.permute(0, 2, 1)
-        k = k.reshape(BATCH_COUNT, n, 2 ** 8)
+        k = k.reshape(m, n, 2 ** 8)
         w = torch.bmm(q, k)
         w /= n ** 0.5
         w = torch.softmax(w, dim=2)
-        v = v.reshape(BATCH_COUNT, n, 2 ** 8)
+        v = v.reshape(m, n, 2 ** 8)
         w = w.permute(0, 2, 1)
         h = torch.bmm(v, w)
-        h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
+        h = h.reshape(m, n, 2 ** 4, 2 ** 4)
         h = self.proj_out.forward(h)
         return x + h
@@ -169,10 +167,10 @@
 
     def forward(self, z: Tensor) -> Tensor:
         z = self.embedding.forward(z)
-        z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
+        z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
         z = self.post_quant_conv.forward(z)
         z = self.decoder.forward(z)
         z = z.permute(0, 2, 3, 1)
         z = z.clip(0.0, 1.0) * 255
-        return z[0]
+        return z
diff --git a/min_dalle/text_tokenizer.py b/min_dalle/text_tokenizer.py
index 01d2111..0cc0a55 100644
--- a/min_dalle/text_tokenizer.py
+++ b/min_dalle/text_tokenizer.py
@@ -1,15 +1,13 @@
 from math import inf
-from typing import List, Tuple
-
 
 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
+    def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True):
         self.is_verbose = is_verbose
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
 
-    def tokenize(self, text: str) -> List[int]:
+    def tokenize(self, text: str) -> list[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -21,8 +19,8 @@
         ]
         return [cls_token] + tokens + [sep_token]
 
-    def get_byte_pair_encoding(self, word: str) -> List[str]:
-        def get_pair_rank(pair: Tuple[str, str]) -> int:
+    def get_byte_pair_encoding(self, word: str) -> list[str]:
+        def get_pair_rank(pair: tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
         subwords = [chr(ord(" ") + 256)] + list(word)
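
Usage note: with this patch applied, the model samples grid_size ** 2 images in a single batch and stitches them into one grid image. A minimal sketch of the new API follows; generate_image's signature is taken directly from the diff, while the MinDalle constructor keywords are inferred from the attributes set in __init__ and the CLI defaults above, so treat them as assumptions:

    from min_dalle import MinDalle

    # assumes weights live under ./pretrained, matching the CLI defaults above
    model = MinDalle(is_mega=False, models_root='pretrained', is_reusable=True)

    # grid_size=2 samples four 256x256 images and tiles them into one
    # 512x512 image; per the new check above, seeding only happens for seed > 0
    image = model.generate_image('alien life', seed=7, grid_size=2)
    image.save('generated.png')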