display intermediate images

commit 0d9998926d (parent b634375edf)
Author: Brett Kuprel
Date:   2022-07-04 16:06:49 -04:00
7 changed files with 107 additions and 85 deletions

README.md
@@ -8,7 +8,7 @@
 This is a fast, minimal implementation of Boris Dayma's [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch.
-To generate a 4x4 grid of DALL·E Mega images it takes
+To generate a 4x4 grid of DALL·E Mega images it takes:
 - 89 sec with a T4 in Colab
 - 48 sec with a P100 in Colab
 - 14 sec with an A100 on Replicate

README.rst

@@ -8,9 +8,9 @@ Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped
 down for inference and converted to PyTorch. The only third party
 dependencies are numpy, requests, pillow and torch.
-It takes - **35 seconds** to generate a 3x3 grid with a P100 in Colab -
-**16 seconds** to generate a 4x4 grid with an A100 on Replicate -
-**TBD** to generate a 4x4 grid with an H100 (@NVIDIA?)
+To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
+T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
+Replicate - TBD with an H100 (@NVIDIA?)
 The flax model and code for converting it to torch can be found
 `here <https://github.com/kuprel/min-dalle-flax>`__.

image_from_text.py

@@ -1,7 +1,6 @@
 import argparse
 import os
 from PIL import Image
 from min_dalle import MinDalle
@@ -9,7 +8,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
 parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
-parser.add_argument('--text', type=str, default='alien life')
+parser.add_argument('--text', type=str, default='Dali painting of WALL·E')
 parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--grid-size', type=int, default=1)
 parser.add_argument('--image-path', type=str, default='generated')
@@ -17,7 +16,7 @@ parser.add_argument('--models-root', type=str, default='pretrained')
 parser.add_argument('--row-count', type=int, default=16) # for debugging

-def ascii_from_image(image: Image.Image, size: int) -> str:
+def ascii_from_image(image: Image.Image, size: int = 128) -> str:
     rgb_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata()
     chars = list('.,;/IOX')
     chars = [chars[i * len(chars) // 256] for i in rgb_pixels]
@@ -57,12 +56,13 @@ def generate_image(
             text,
             seed,
             grid_size ** 2,
-            row_count
+            row_count,
+            is_verbose=True
         )
         image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy()
         print('image tokens', image_tokens)
     else:
-        image = model.generate_image(text, seed, grid_size)
+        image = model.generate_image(text, seed, grid_size, is_verbose=True)
     save_image(image, image_path)
     print(ascii_from_image(image, size=128))
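With `size` now defaulting to 128, callers may omit the explicit argument. A minimal sketch of the ASCII preview helper in isolation (the file name is illustrative, standing in for an image saved by `save_image` above):

```python
from PIL import Image

# Hypothetical: load a previously saved grid and print its ASCII preview.
image = Image.open('generated.png')
print(ascii_from_image(image))  # same result as ascii_from_image(image, size=128)
```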

min_dalle/min_dalle.py

@@ -1,10 +1,11 @@
 import os
 from PIL import Image
 import numpy
-from torch import LongTensor
+from torch import LongTensor, FloatTensor
 import torch
 import json
 import requests
+from typing import Callable, Tuple

 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
@@ -26,7 +27,6 @@ class MinDalle:
         self.is_reusable = is_reusable
         self.is_verbose = is_verbose
         self.text_token_count = 64
-        self.image_token_count = 256
         self.layer_count = 24 if is_mega else 12
         self.attention_head_count = 32 if is_mega else 16
         self.embed_count = 2048 if is_mega else 1024
@@ -91,7 +91,7 @@
             vocab = json.load(f)
         with open(self.merges_path, 'r', encoding='utf8') as f:
             merges = f.read().split("\n")[1:-1]
-        self.tokenizer = TextTokenizer(vocab, merges, is_verbose=self.is_verbose)
+        self.tokenizer = TextTokenizer(vocab, merges)

     def init_encoder(self):
@@ -117,7 +117,6 @@
         if not is_downloaded: self.download_decoder()
         if self.is_verbose: print("initializing DalleBartDecoder")
         self.decoder = DalleBartDecoder(
-            image_token_count = self.image_token_count,
             image_vocab_count = self.image_vocab_count,
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
@@ -142,16 +141,37 @@
         if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()

+    def image_from_tokens(
+        self,
+        grid_size: int,
+        image_tokens: LongTensor,
+        is_verbose: bool = False
+    ) -> Image.Image:
+        if not self.is_reusable: del self.decoder
+        if torch.cuda.is_available(): torch.cuda.empty_cache()
+        if not self.is_reusable: self.init_detokenizer()
+        if is_verbose: print("detokenizing image")
+        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        if not self.is_reusable: del self.detokenizer
+        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
+        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
+        image = Image.fromarray(image.to('cpu').detach().numpy())
+        return image
+
     def generate_image_tokens(
         self,
         text: str,
         seed: int,
-        image_count: int,
-        row_count: int
+        grid_size: int,
+        row_count: int,
+        mid_count: int = None,
+        handle_intermediate_image: Callable[[int, Image.Image], None] = None,
+        is_verbose: bool = False
     ) -> LongTensor:
-        if self.is_verbose: print("tokenizing text")
-        tokens = self.tokenizer.tokenize(text)
-        if self.is_verbose: print("text tokens", tokens)
+        if is_verbose: print("tokenizing text")
+        tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
+        if is_verbose: print("text tokens", tokens)
         text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
         text_tokens[0, :2] = [tokens[0], tokens[-1]]
         text_tokens[1, :len(tokens)] = tokens
@@ -160,40 +180,57 @@
         if torch.cuda.is_available(): text_tokens = text_tokens.cuda()

         if not self.is_reusable: self.init_encoder()
-        if self.is_verbose: print("encoding text tokens")
+        if is_verbose: print("encoding text tokens")
         encoder_state = self.encoder.forward(text_tokens)
         if not self.is_reusable: del self.encoder
         if torch.cuda.is_available(): torch.cuda.empty_cache()

         if not self.is_reusable: self.init_decoder()
-        if self.is_verbose: print("sampling image tokens")
-        if seed > 0: torch.manual_seed(seed)
-        image_tokens = self.decoder.forward(
-            image_count,
-            row_count,
-            text_tokens,
-            encoder_state
-        )
-        if not self.is_reusable: del self.decoder
-        return image_tokens
+        encoder_state, attention_mask, attention_state, image_tokens = (
+            self.decoder.decode_initial(
+                seed,
+                grid_size ** 2,
+                text_tokens,
+                encoder_state
+            )
+        )
+        for row_index in range(row_count):
+            if is_verbose:
+                print('sampling row {} of {}'.format(row_index + 1, row_count))
+            attention_state, image_tokens = self.decoder.decode_row(
+                row_index,
+                encoder_state,
+                attention_mask,
+                attention_state,
+                image_tokens
+            )
+            if mid_count is not None:
+                if ((row_index + 1) * mid_count) % row_count == 0:
+                    tokens = image_tokens[:, 1:]
+                    image = self.image_from_tokens(grid_size, tokens, is_verbose)
+                    handle_intermediate_image(row_index, image)
+        return image_tokens[:, 1:]

     def generate_image(
         self,
         text: str,
         seed: int = -1,
-        grid_size: int = 1
+        grid_size: int = 1,
+        mid_count: int = None,
+        handle_intermediate_image: Callable[[Image.Image], None] = None,
+        is_verbose: bool = False
     ) -> Image.Image:
-        image_count = grid_size ** 2
-        row_count = 16
-        image_tokens = self.generate_image_tokens(text, seed, image_count, row_count)
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        if not self.is_reusable: self.init_detokenizer()
-        if self.is_verbose: print("detokenizing image")
-        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
-        if not self.is_reusable: del self.detokenizer
-        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
-        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
-        image = Image.fromarray(image.to('cpu').detach().numpy())
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        return image
+        image_tokens = self.generate_image_tokens(
+            text,
+            seed,
+            grid_size,
+            row_count = 16,
+            mid_count = mid_count,
+            handle_intermediate_image = handle_intermediate_image,
+            is_verbose = is_verbose
+        )
+        return self.image_from_tokens(grid_size, image_tokens, is_verbose)
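Taken together, this is the commit's headline feature: callers can watch the grid fill in row by row. A usage sketch of the new API, with the prompt, seed, grid size, and file names purely illustrative, and the `MinDalle` constructor keywords assumed to match the attributes shown above. Note that `generate_image_tokens` invokes the callback as `handle_intermediate_image(row_index, image)` (the `Callable[[Image.Image], None]` hint on `generate_image` understates this), and that with `row_count = 16`, `mid_count = 4` fires after rows 4, 8, 12, and 16, since a checkpoint lands wherever `(row_index + 1) * mid_count` is divisible by `row_count`:

```python
from min_dalle import MinDalle

model = MinDalle(is_mega=False, is_reusable=True)  # assumed constructor keywords

def show_progress(row_index: int, image):
    # Called at each checkpoint row with the partially decoded grid;
    # with row_count=16 and mid_count=4 that is rows 4, 8, 12, and 16.
    image.save('intermediate_{}.png'.format(row_index))

image = model.generate_image(
    'Dali painting of WALL·E',   # illustrative prompt (the new CLI default)
    seed=7,
    grid_size=2,
    mid_count=4,                 # request 4 evenly spaced intermediate images
    handle_intermediate_image=show_progress,
    is_verbose=True
)
image.save('final.png')
```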

min_dalle/models/dalle_bart_decoder.py

@@ -5,6 +5,9 @@ torch.set_grad_enabled(False)
 from .dalle_bart_encoder import GLU, AttentionBase

+IMAGE_TOKEN_COUNT = 256
+BLANK_TOKEN = 6965
+

 class DecoderCrossAttention(AttentionBase):
     def forward(
@@ -20,9 +23,9 @@ class DecoderCrossAttention(AttentionBase):
 class DecoderSelfAttention(AttentionBase):
-    def __init__(self, head_count: int, embed_count: int, token_count: int):
+    def __init__(self, head_count: int, embed_count: int):
         super().__init__(head_count, embed_count)
-        token_indices = torch.arange(token_count)
+        token_indices = torch.arange(IMAGE_TOKEN_COUNT)
         if torch.cuda.is_available(): token_indices = token_indices.cuda()
         self.token_indices = token_indices
@@ -48,19 +51,13 @@
 class DecoderLayer(nn.Module):
     def __init__(
         self,
-        image_token_count: int,
         head_count: int,
         embed_count: int,
         glu_embed_count: int
     ):
         super().__init__()
-        self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttention(
-            head_count,
-            embed_count,
-            image_token_count
-        )
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
         self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
@@ -110,7 +107,6 @@ class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,
-        image_token_count: int,
         embed_count: int,
         attention_head_count: int,
         glu_embed_count: int,
@@ -121,12 +117,10 @@
         self.layer_count = layer_count
         self.embed_count = embed_count
         self.condition_factor = 10.0
-        self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
-        self.embed_positions = nn.Embedding(image_token_count, embed_count)
+        self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
         self.layers: List[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
-                image_token_count,
                 attention_head_count,
                 embed_count,
                 glu_embed_count
@@ -137,7 +131,7 @@
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
         self.zero_prob = torch.zeros([1])
-        self.token_indices = torch.arange(self.image_token_count)
+        self.token_indices = torch.arange(IMAGE_TOKEN_COUNT)
         self.start_token = torch.tensor([start_token]).to(torch.long)
         if torch.cuda.is_available():
             self.zero_prob = self.zero_prob.cuda()
@@ -188,8 +182,8 @@
     def decode_row(
         self,
         row_index: int,
-        attention_mask: BoolTensor,
         encoder_state: FloatTensor,
+        attention_mask: BoolTensor,
         attention_state: FloatTensor,
         image_tokens_sequence: LongTensor
     ) -> Tuple[FloatTensor, LongTensor]:
@@ -202,19 +196,18 @@
                 prev_tokens = image_tokens_sequence[:, i],
                 token_index = self.token_indices[[i]]
             )
             image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]

         return attention_state, image_tokens_sequence

-    def forward(
+    def decode_initial(
         self,
+        seed: int,
         image_count: int,
-        row_count: int,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
-    ) -> LongTensor:
+    ) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
         expanded_indices = [0] * image_count + [1] * image_count
         text_tokens = text_tokens[expanded_indices]
         encoder_state = encoder_state[expanded_indices]
@@ -223,13 +216,13 @@
         attention_state_shape = (
             self.layer_count,
             image_count * 4,
-            self.image_token_count,
+            IMAGE_TOKEN_COUNT,
             self.embed_count
         )
         attention_state = torch.zeros(attention_state_shape)
         image_tokens_sequence = torch.full(
-            (image_count, self.image_token_count + 1),
-            6965, # black token
+            (image_count, IMAGE_TOKEN_COUNT + 1),
+            BLANK_TOKEN,
             dtype=torch.long
         )
         if torch.cuda.is_available():
@@ -238,13 +231,6 @@
         image_tokens_sequence[:, 0] = self.start_token[0]

-        for row_index in range(row_count):
-            attention_state, image_tokens_sequence = self.decode_row(
-                row_index,
-                attention_mask,
-                encoder_state,
-                attention_state,
-                image_tokens_sequence
-            )
+        if seed > 0: torch.manual_seed(seed)

-        return image_tokens_sequence[:, 1:]
+        return encoder_state, attention_mask, attention_state, image_tokens_sequence
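The refactor splits the old `forward` into `decode_initial` (seed the RNG, expand the batch, allocate attention and token buffers) and `decode_row` (fill one row of tokens per call), which is what lets `MinDalle` surface intermediate images. A sketch of how the pieces compose, mirroring the new caller in `min_dalle.py`, where `decoder`, `text_tokens`, and `encoder_state` are assumed to be prepared as in `generate_image_tokens`:

```python
# Allocate buffers and seed sampling once...
encoder_state, attention_mask, attention_state, image_tokens = (
    decoder.decode_initial(seed, image_count, text_tokens, encoder_state)
)
# ...then decode the 256 image tokens one row at a time
# (row_count is fixed at 16 upstream, so each row is 16 tokens).
for row_index in range(16):
    attention_state, image_tokens = decoder.decode_row(
        row_index, encoder_state, attention_mask, attention_state, image_tokens
    )
image_tokens = image_tokens[:, 1:]  # drop the start token, as the caller does
```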

min_dalle/text_tokenizer.py

@@ -2,13 +2,12 @@ from math import inf
 from typing import List, Tuple

 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
-        self.is_verbose = is_verbose
+    def __init__(self, vocab: dict, merges: List[str]):
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

-    def tokenize(self, text: str) -> List[int]:
+    def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -16,11 +15,11 @@
         tokens = [
             self.token_from_subword.get(subword, unk_token)
             for word in text.split(" ") if len(word) > 0
-            for subword in self.get_byte_pair_encoding(word)
+            for subword in self.get_byte_pair_encoding(word, is_verbose)
         ]
         return [cls_token] + tokens + [sep_token]

-    def get_byte_pair_encoding(self, word: str) -> List[str]:
+    def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
         def get_pair_rank(pair: Tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
@@ -36,5 +35,5 @@
                 (subwords[i + 2:] if i + 2 < len(subwords) else [])
             )
-        if self.is_verbose: print(subwords)
+        if is_verbose: print(subwords)
         return subwords
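Verbosity moves from a constructor flag to a per-call argument, so one `TextTokenizer` instance can serve quiet and chatty callers alike. A before/after sketch, where `vocab` and `merges` are assumed to be loaded from the model files as in `MinDalle.init_tokenizer`:

```python
# Before this commit: verbosity fixed at construction time.
# tokenizer = TextTokenizer(vocab, merges, is_verbose=True)
# tokens = tokenizer.tokenize('alien life')

# After: off by default, opted into per call.
tokenizer = TextTokenizer(vocab, merges)
tokens = tokenizer.tokenize('alien life', is_verbose=True)  # prints subwords as pairs merge
```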

setup.py

@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.2.17',
+    version='0.2.21',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',