From 0d9998926d59e0dee144874317e0a490a45db8d1 Mon Sep 17 00:00:00 2001
From: Brett Kuprel <brkuprel@gmail.com>
Date: Mon, 4 Jul 2022 16:06:49 -0400
Subject: [PATCH] display intermediate images

---
 README.md                              |   2 +-
 README.rst                             |   6 +-
 image_from_text.py                     |  10 +--
 min_dalle/min_dalle.py                 | 107 +++++++++++++++++--------
 min_dalle/models/dalle_bart_decoder.py |  54 +++++--------
 min_dalle/text_tokenizer.py            |  11 ++-
 setup.py                               |   2 +-
 7 files changed, 107 insertions(+), 85 deletions(-)

diff --git a/README.md b/README.md
index 14222ee..3901994 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ This is a fast, minimal implementation of Boris Dayma's
 [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been
 stripped down for inference and converted to PyTorch. The only third
 party dependencies are numpy, requests, pillow and torch.
-To generate a 4x4 grid of DALL·E Mega images it takes
+To generate a 4x4 grid of DALL·E Mega images it takes:
 - 89 sec with a T4 in Colab
 - 48 sec with a P100 in Colab
 - 14 sec with an A100 on Replicate

diff --git a/README.rst b/README.rst
index 726495c..6dae41c 100644
--- a/README.rst
+++ b/README.rst
@@ -8,9 +8,9 @@ Mega <https://github.com/borisdayma/dalle-mini>`__. It has been
 stripped down for inference and converted to PyTorch. The only third
 party dependencies are numpy, requests, pillow and torch.
 
-It takes - **35 seconds** to generate a 3x3 grid with a P100 in Colab -
-**16 seconds** to generate a 4x4 grid with an A100 on Replicate -
-**TBD** to generate a 4x4 grid with an H100 (@NVIDIA?)
+To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
+T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
+Replicate - TBD with an H100 (@NVIDIA?)
 
 The flax model and code for converting it to torch can be found
 `here <https://github.com/kuprel/min-dalle-flax>`__.

diff --git a/image_from_text.py b/image_from_text.py
index 77429f8..582f24e 100644
--- a/image_from_text.py
+++ b/image_from_text.py
@@ -1,7 +1,6 @@
 import argparse
 import os
 from PIL import Image
-
 from min_dalle import MinDalle
 
 
@@ -9,7 +8,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
 parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
-parser.add_argument('--text', type=str, default='alien life')
+parser.add_argument('--text', type=str, default='Dali painting of WALL·E')
 parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--grid-size', type=int, default=1)
 parser.add_argument('--image-path', type=str, default='generated')
@@ -17,7 +16,7 @@ parser.add_argument('--models-root', type=str, default='pretrained')
 parser.add_argument('--row-count', type=int, default=16) # for debugging
 
 
-def ascii_from_image(image: Image.Image, size: int) -> str:
+def ascii_from_image(image: Image.Image, size: int = 128) -> str:
     rgb_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata()
     chars = list('.,;/IOX')
     chars = [chars[i * len(chars) // 256] for i in rgb_pixels]
@@ -57,12 +56,13 @@ def generate_image(
             text,
             seed,
-            grid_size ** 2,
-            row_count
+            grid_size,
+            row_count,
+            is_verbose=True
         )
         image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy()
         print('image tokens', image_tokens)
     else:
-        image = model.generate_image(text, seed, grid_size)
+        image = model.generate_image(text, seed, grid_size, is_verbose=True)
 
     save_image(image, image_path)
     print(ascii_from_image(image, size=128))
diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py
index f156f5c..ab76f2d 100644
--- a/min_dalle/min_dalle.py
+++ b/min_dalle/min_dalle.py
@@ -1,10 +1,11 @@
 import os
 from PIL import Image
 import numpy
-from torch import LongTensor
+from torch import LongTensor, FloatTensor
 import torch
 import json
 import requests
+from typing import Callable, Tuple
 
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
@@ -26,7 +27,6 @@ class MinDalle:
         self.is_reusable = is_reusable
         self.is_verbose = is_verbose
         self.text_token_count = 64
-        self.image_token_count = 256
         self.layer_count = 24 if is_mega else 12
         self.attention_head_count = 32 if is_mega else 16
         self.embed_count = 2048 if is_mega else 1024
@@ -91,7 +91,7 @@ class MinDalle:
             vocab = json.load(f)
         with open(self.merges_path, 'r', encoding='utf8') as f:
             merges = f.read().split("\n")[1:-1]
-        self.tokenizer = TextTokenizer(vocab, merges, is_verbose=self.is_verbose)
+        self.tokenizer = TextTokenizer(vocab, merges)
 
 
     def init_encoder(self):
@@ -117,7 +117,6 @@ class MinDalle:
         if not is_downloaded: self.download_decoder()
         if self.is_verbose: print("initializing DalleBartDecoder")
         self.decoder = DalleBartDecoder(
-            image_token_count = self.image_token_count,
             image_vocab_count = self.image_vocab_count,
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
@@ -142,16 +141,37 @@ class MinDalle:
         if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
 
 
+    def image_from_tokens(
+        self,
+        grid_size: int,
+        image_tokens: LongTensor,
+        is_verbose: bool = False
+    ) -> Image.Image:
+        if not self.is_reusable: del self.decoder
+        if torch.cuda.is_available(): torch.cuda.empty_cache()
+        if not self.is_reusable: self.init_detokenizer()
+        if is_verbose: print("detokenizing image")
+        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        if not self.is_reusable: del self.detokenizer
+        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
+        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
+        image = Image.fromarray(image.to('cpu').detach().numpy())
+        return image
+
+
     def generate_image_tokens(
         self,
         text: str,
         seed: int,
-        image_count: int,
-        row_count: int
+        grid_size: int,
+        row_count: int,
+        mid_count: int = None,
+        handle_intermediate_image: Callable[[int, Image.Image], None] = None,
+        is_verbose: bool = False
     ) -> LongTensor:
-        if self.is_verbose: print("tokenizing text")
-        tokens = self.tokenizer.tokenize(text)
-        if self.is_verbose: print("text tokens", tokens)
+        if is_verbose: print("tokenizing text")
+        tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
+        if is_verbose: print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
         text_tokens[1, :len(tokens)] = tokens
@@ -160,40 +180,57 @@ class MinDalle:
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
 
         if not self.is_reusable: self.init_encoder()
-        if self.is_verbose: print("encoding text tokens")
+        if is_verbose: print("encoding text tokens")
         encoder_state = self.encoder.forward(text_tokens)
         if not self.is_reusable: del self.encoder
         if torch.cuda.is_available(): torch.cuda.empty_cache()
 
         if not self.is_reusable: self.init_decoder()
-        if self.is_verbose: print("sampling image tokens")
-        if seed > 0: torch.manual_seed(seed)
-        image_tokens = self.decoder.forward(
-            image_count,
-            row_count,
-            text_tokens,
-            encoder_state
+
+        encoder_state, attention_mask, attention_state, image_tokens = (
+            self.decoder.decode_initial(
+                seed,
+                grid_size ** 2,
+                text_tokens,
+                encoder_state
+            )
         )
-        if not self.is_reusable: del self.decoder
-        return image_tokens
-
+
+        for row_index in range(row_count):
+            if is_verbose:
+                print('sampling row {} of {}'.format(row_index + 1, row_count))
+            attention_state, image_tokens = self.decoder.decode_row(
+                row_index,
+                encoder_state,
+                attention_mask,
+                attention_state,
+                image_tokens
+            )
+            if mid_count is not None:
+                if ((row_index + 1) * mid_count) % row_count == 0:
+                    tokens = image_tokens[:, 1:]
+                    image = self.image_from_tokens(grid_size, tokens, is_verbose)
+                    handle_intermediate_image(row_index, image)
+
+        return image_tokens[:, 1:]
+
 
     def generate_image(
         self,
-        text: str, 
+        text: str,
         seed: int = -1,
-        grid_size: int = 1
+        grid_size: int = 1,
+        mid_count: int = None,
+        handle_intermediate_image: Callable[[int, Image.Image], None] = None,
+        is_verbose: bool = False
     ) -> Image.Image:
-        image_count = grid_size ** 2
-        row_count = 16
-        image_tokens = self.generate_image_tokens(text, seed, image_count, row_count)
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        if not self.is_reusable: self.init_detokenizer()
-        if self.is_verbose: print("detokenizing image")
-        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
-        if not self.is_reusable: del self.detokenizer
-        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
-        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
-        image = Image.fromarray(image.to('cpu').detach().numpy())
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        return image
\ No newline at end of file
+        image_tokens = self.generate_image_tokens(
+            text,
+            seed,
+            grid_size,
+            row_count = 16,
+            mid_count = mid_count,
+            handle_intermediate_image = handle_intermediate_image,
+            is_verbose = is_verbose
+        )
+        return self.image_from_tokens(grid_size, image_tokens, is_verbose)
\ No newline at end of file
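
[Usage note, outside the diff: a minimal sketch of the intermediate-image
callback introduced above. The signatures come from this patch; the prompt,
seed, and file names are illustrative assumptions.]

    from min_dalle import MinDalle
    from PIL import Image

    model = MinDalle(is_mega=False, is_reusable=True)

    # Receives the row index and a freshly detokenized grid. With row_count=16
    # and mid_count=4, ((row_index + 1) * mid_count) % row_count == 0 fires
    # after rows 4, 8, 12, and 16: mid_count evenly spaced snapshots (when
    # mid_count divides row_count), the last one matching the final image.
    def show_progress(row_index: int, image: Image.Image):
        image.save('intermediate_{:02d}.png'.format(row_index + 1))

    image = model.generate_image(
        'Dali painting of WALL·E',
        seed=7,
        grid_size=2,
        mid_count=4,
        handle_intermediate_image=show_progress,
        is_verbose=True
    )
    image.save('generated.png')
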
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index cf82704..dcbc680 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -5,6 +5,9 @@
 torch.set_grad_enabled(False)
 
 from .dalle_bart_encoder import GLU, AttentionBase
 
+IMAGE_TOKEN_COUNT = 256
+BLANK_TOKEN = 6965
+
 
 class DecoderCrossAttention(AttentionBase):
     def forward(
@@ -20,9 +23,9 @@
 
 
 class DecoderSelfAttention(AttentionBase):
-    def __init__(self, head_count: int, embed_count: int, token_count: int):
+    def __init__(self, head_count: int, embed_count: int):
         super().__init__(head_count, embed_count)
-        token_indices = torch.arange(token_count)
+        token_indices = torch.arange(IMAGE_TOKEN_COUNT)
         if torch.cuda.is_available(): token_indices = token_indices.cuda()
         self.token_indices = token_indices
 
@@ -48,19 +51,13 @@
 class DecoderLayer(nn.Module):
     def __init__(
         self,
-        image_token_count: int,
         head_count: int,
         embed_count: int,
         glu_embed_count: int
     ):
         super().__init__()
-        self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttention(
-            head_count,
-            embed_count,
-            image_token_count
-        )
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
         self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
@@ -110,7 +107,6 @@ class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,
-        image_token_count: int,
         embed_count: int,
         attention_head_count: int,
         glu_embed_count: int,
@@ -121,12 +117,10 @@ class DalleBartDecoder(nn.Module):
         self.layer_count = layer_count
         self.embed_count = embed_count
         self.condition_factor = 10.0
-        self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
-        self.embed_positions = nn.Embedding(image_token_count, embed_count)
+        self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
         self.layers: List[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
-                image_token_count,
                 attention_head_count,
                 embed_count,
                 glu_embed_count
@@ -137,7 +131,7 @@ class DalleBartDecoder(nn.Module):
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
         self.zero_prob = torch.zeros([1])
-        self.token_indices = torch.arange(self.image_token_count)
+        self.token_indices = torch.arange(IMAGE_TOKEN_COUNT)
         self.start_token = torch.tensor([start_token]).to(torch.long)
         if torch.cuda.is_available():
             self.zero_prob = self.zero_prob.cuda()
@@ -183,13 +177,13 @@ class DalleBartDecoder(nn.Module):
             torch.exp(logits - top_logits[:, [0]])
         )
         return probs, attention_state
-    
+
 
     def decode_row(
         self,
         row_index: int,
-        attention_mask: BoolTensor,
         encoder_state: FloatTensor,
+        attention_mask: BoolTensor,
         attention_state: FloatTensor,
         image_tokens_sequence: LongTensor
     ) -> Tuple[FloatTensor, LongTensor]:
@@ -202,19 +196,18 @@ class DalleBartDecoder(nn.Module):
                 prev_tokens = image_tokens_sequence[:, i],
                 token_index = self.token_indices[[i]]
             )
-            
             image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
         return attention_state, image_tokens_sequence
-    
-    def forward(
+
+    def decode_initial(
         self,
+        seed: int,
         image_count: int,
-        row_count: int,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
-    ) -> LongTensor:
+    ) -> Tuple[FloatTensor, BoolTensor, FloatTensor, LongTensor]:
         expanded_indices = [0] * image_count + [1] * image_count
         text_tokens = text_tokens[expanded_indices]
         encoder_state = encoder_state[expanded_indices]
@@ -223,13 +216,13 @@ class DalleBartDecoder(nn.Module):
         attention_state_shape = (
             self.layer_count,
             image_count * 4,
-            self.image_token_count,
+            IMAGE_TOKEN_COUNT,
             self.embed_count
         )
         attention_state = torch.zeros(attention_state_shape)
         image_tokens_sequence = torch.full(
-            (image_count, self.image_token_count + 1),
-            6965, # black token
+            (image_count, IMAGE_TOKEN_COUNT + 1),
+            BLANK_TOKEN,
             dtype=torch.long
         )
         if torch.cuda.is_available():
@@ -238,13 +231,6 @@ class DalleBartDecoder(nn.Module):
 
         image_tokens_sequence[:, 0] = self.start_token[0]
 
-        for row_index in range(row_count):
-            attention_state, image_tokens_sequence = self.decode_row(
-                row_index,
-                attention_mask,
-                encoder_state,
-                attention_state,
-                image_tokens_sequence
-            )
-
-        return image_tokens_sequence[:, 1:]
\ No newline at end of file
+        if seed > 0: torch.manual_seed(seed)
+
+        return encoder_state, attention_mask, attention_state, image_tokens_sequence
\ No newline at end of file
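
[Reference note, outside the diff: DalleBartDecoder.forward is split into
decode_initial and decode_row above, so the caller now owns the sampling loop.
A sketch of driving the decoder directly, assuming decoder, seed, image_count,
text_tokens, and encoder_state are prepared as in MinDalle.generate_image_tokens:]

    # decode_initial duplicates the two text-conditioning rows across the
    # batch, allocates the zeroed attention cache, fills the token buffer with
    # BLANK_TOKEN (start token at position 0), seeds the RNG when seed > 0,
    # and hands back all the state that decode_row threads through.
    encoder_state, attention_mask, attention_state, image_tokens = (
        decoder.decode_initial(seed, image_count, text_tokens, encoder_state)
    )

    # IMAGE_TOKEN_COUNT = 256 tokens are sampled as 16 rows of 16; after any
    # row, image_tokens[:, 1:] holds a partial grid that can be detokenized.
    for row_index in range(16):
        attention_state, image_tokens = decoder.decode_row(
            row_index,
            encoder_state,
            attention_mask,
            attention_state,
            image_tokens
        )

    image_tokens = image_tokens[:, 1:]  # drop the start token
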
diff --git a/min_dalle/text_tokenizer.py b/min_dalle/text_tokenizer.py
index f2f201a..31114d6 100644
--- a/min_dalle/text_tokenizer.py
+++ b/min_dalle/text_tokenizer.py
@@ -2,13 +2,12 @@ from math import inf
 from typing import List, Tuple
 
 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
-        self.is_verbose = is_verbose
+    def __init__(self, vocab: dict, merges: List[str]):
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
 
-    def tokenize(self, text: str) -> List[int]:
+    def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -16,11 +15,11 @@ class TextTokenizer:
         tokens = [
             self.token_from_subword.get(subword, unk_token)
             for word in text.split(" ") if len(word) > 0
-            for subword in self.get_byte_pair_encoding(word)
+            for subword in self.get_byte_pair_encoding(word, is_verbose)
         ]
         return [cls_token] + tokens + [sep_token]
 
-    def get_byte_pair_encoding(self, word: str) -> List[str]:
+    def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
         def get_pair_rank(pair: Tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
 
@@ -36,5 +35,5 @@ class TextTokenizer:
                 (subwords[i + 2:] if i + 2 < len(subwords) else [])
             )
 
-        if self.is_verbose: print(subwords)
+        if is_verbose: print(subwords)
         return subwords
\ No newline at end of file
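
[API note, outside the diff: verbosity on the tokenizer is now a per-call flag
instead of per-instance state, so the single tokenizer MinDalle keeps around
can serve both quiet and verbose generations. A sketch, with vocab and merges
as loaded in MinDalle.init_tokenizer:]

    tokenizer = TextTokenizer(vocab, merges)   # constructor no longer takes is_verbose
    tokens = tokenizer.tokenize(text, is_verbose=True)   # prints each word's subwords
    tokens = tokenizer.tokenize(text)                    # quiet: is_verbose defaults to False
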
diff --git a/setup.py b/setup.py
index fdbaa98..7757bc8 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.2.17',
+    version='0.2.21',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',