added grid_size parameter to generate a grid of images

Brett Kuprel 2022-07-02 08:45:49 -04:00
parent e0386f991c
commit 1eb56737d8
6 changed files with 87 additions and 69 deletions

View File

@@ -11,9 +11,10 @@ parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
 parser.add_argument('--text', type=str, default='alien life')
 parser.add_argument('--seed', type=int, default=-1)
-parser.add_argument('--image_path', type=str, default='generated')
-parser.add_argument('--models_root', type=str, default='pretrained')
-parser.add_argument('--token_count', type=int, default=256) # for debugging
+parser.add_argument('--grid-size', type=int, default=1)
+parser.add_argument('--image-path', type=str, default='generated')
+parser.add_argument('--models-root', type=str, default='pretrained')
+parser.add_argument('--token-count', type=int, default=256) # for debugging
 
 def ascii_from_image(image: Image.Image, size: int) -> str:
@@ -38,6 +39,7 @@ def generate_image(
     is_mega: bool,
     text: str,
     seed: int,
+    grid_size: int,
     image_path: str,
     models_root: str,
     token_count: int
@@ -51,10 +53,10 @@ def generate_image(
     )
     if token_count < 256:
-        image_tokens = model.generate_image_tokens(text, seed)
-        print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
+        image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
+        print('image tokens', image_tokens.to('cpu').detach().numpy())
     else:
-        image = model.generate_image(text, seed)
+        image = model.generate_image(text, seed, grid_size)
         save_image(image, image_path)
         print(ascii_from_image(image, size=128))
@@ -66,6 +68,7 @@ if __name__ == '__main__':
         is_mega=args.mega,
         text=args.text,
         seed=args.seed,
+        grid_size=args.grid_size,
         image_path=args.image_path,
         models_root=args.models_root,
         token_count=args.token_count
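
The renamed flags are breaking changes (--image_path becomes --image-path, and so on); argparse maps each dashed flag back to an underscored attribute such as args.grid_size. A hypothetical invocation requesting a 3x3 grid, assuming this argparse block belongs to the repo's command-line entry script (the filename is not shown in this diff):

    python image_from_text.py --text='alien life' --grid-size=3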

View File

@@ -1,4 +1,5 @@
 import os
+from re import I
 from PIL import Image
 import numpy
 from torch import LongTensor
@@ -28,7 +29,6 @@ class MinDalle:
         self.is_reusable = is_reusable
         self.is_verbose = is_verbose
         self.sample_token_count = sample_token_count
-        self.batch_count = 2
         self.text_token_count = 64
         self.image_token_count = 256
         self.layer_count = 24 if is_mega else 12
@@ -128,8 +128,7 @@ class MinDalle:
             embed_count = self.embed_count,
             glu_embed_count = self.glu_embed_count,
             layer_count = self.layer_count,
-            start_token = self.image_vocab_count,
-            batch_count = self.batch_count
+            start_token = self.image_vocab_count
         )
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
@ -148,7 +147,12 @@ class MinDalle:
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
def generate_image_tokens(self, text: str, seed: int) -> LongTensor: def generate_image_tokens(
self,
text: str,
seed: int,
image_count: int
) -> LongTensor:
if self.is_verbose: print("tokenizing text") if self.is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text) tokens = self.tokenizer.tokenize(text)
if self.is_verbose: print("text tokens", tokens) if self.is_verbose: print("text tokens", tokens)
@@ -166,18 +170,29 @@ class MinDalle:
         if not self.is_reusable: self.init_decoder()
         if self.is_verbose: print("sampling image tokens")
-        if seed < 0: seed = random.randint(0, 2 ** 31)
-        torch.manual_seed(seed)
-        image_tokens = self.decoder.forward(text_tokens, encoder_state)
+        if seed > 0: torch.manual_seed(seed)
+        image_tokens = self.decoder.forward(
+            image_count,
+            text_tokens,
+            encoder_state
+        )
         if not self.is_reusable: del self.decoder
         return image_tokens
 
-    def generate_image(self, text: str, seed: int) -> Image.Image:
-        image_tokens = self.generate_image_tokens(text, seed)
+    def generate_image(
+        self,
+        text: str,
+        seed: int = -1,
+        grid_size: int = 1
+    ) -> Image.Image:
+        image_count = grid_size ** 2
+        image_tokens = self.generate_image_tokens(text, seed, image_count)
         if not self.is_reusable: self.init_detokenizer()
         if self.is_verbose: print("detokenizing image")
-        image = self.detokenizer.forward(image_tokens).to(torch.uint8)
+        images = self.detokenizer.forward(image_tokens).to(torch.uint8)
         if not self.is_reusable: del self.detokenizer
+        images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
+        image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
         image = Image.fromarray(image.to('cpu').detach().numpy())
         return image
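
The two new lines after the detokenizer call are what stitch the grid_size ** 2 decoded images into a single sheet: reshape to (grid_size, grid_size, H, W, 3), merge each row of tiles, then transpose and merge again. A minimal sketch of just that stitching, with small fake "images" standing in for the 256x256 decoder outputs:

    import torch

    grid_size = 2
    # four fake 4x4 RGB images; every pixel stores its image's index
    images = torch.arange(grid_size ** 2)[:, None, None, None].expand(-1, 4, 4, 3)

    # same two lines as generate_image above
    images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
    image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)

    print(image.shape)      # torch.Size([8, 8, 3])
    print(image[..., 0])    # four constant 4x4 blocks, one per source image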

View File

@@ -1,4 +1,3 @@
-from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
@@ -26,7 +25,7 @@ class DecoderSelfAttention(AttentionBase):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
-    ) -> Tuple[FloatTensor, FloatTensor]:
+    ) -> tuple[FloatTensor, FloatTensor]:
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
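
The Tuple -> tuple change here (and the matching List -> list changes throughout this commit) switches to PEP 585 builtin generics, which quietly raises the minimum supported Python to 3.9: annotations are evaluated at definition time, so on 3.8 a subscripted builtin fails on import. An illustrative snippet (names are made up):

    # works on Python >= 3.9; raises TypeError at definition time on 3.8
    def split_state(state: tuple[float, float]) -> list[float]:
        return list(state)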
@@ -71,13 +70,13 @@ class DecoderLayer(nn.Module):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
-    ) -> Tuple[FloatTensor, FloatTensor]:
+    ) -> tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
         self_attn_mask = self.token_indices < token_index + 1
+        self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
         token_mask = self.token_indices == token_index
-        self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
         decoder_state, attention_state = self.self_attn.forward(
             decoder_state,
             attention_state,
@@ -116,17 +115,17 @@ class DalleBartDecoder(nn.Module):
         attention_head_count: int,
         glu_embed_count: int,
         layer_count: int,
-        batch_count: int,
         start_token: int
     ):
         super().__init__()
         self.layer_count = layer_count
+        self.embed_count = embed_count
         self.sample_token_count = sample_token_count
         self.condition_factor = 10.0
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayer] = nn.ModuleList([
+        self.layers: list[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
                 image_token_count,
                 attention_head_count,
@@ -138,12 +137,6 @@ class DalleBartDecoder(nn.Module):
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
         self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
-        self.attention_state_shape = (
-            layer_count,
-            2 * batch_count,
-            image_token_count,
-            embed_count
-        )
         self.zero_prob = torch.zeros([1])
         self.token_indices = torch.arange(self.sample_token_count)
         self.start_token = torch.tensor([start_token]).to(torch.long)
@@ -155,17 +148,16 @@ class DalleBartDecoder(nn.Module):
     def decode_step(
         self,
-        text_tokens: LongTensor,
+        attention_mask: BoolTensor,
         encoder_state: FloatTensor,
         attention_state: FloatTensor,
-        prev_token: LongTensor,
+        prev_tokens: LongTensor,
         token_index: LongTensor
-    ) -> Tuple[LongTensor, FloatTensor]:
-        attention_mask = text_tokens.not_equal(1)
-        batch_count = encoder_state.shape[0]
-        prev_token_batched = torch.cat([prev_token] * batch_count)
-        token_index_batched = torch.cat([token_index] * batch_count)
-        decoder_state = self.embed_tokens.forward(prev_token_batched)
+    ) -> tuple[LongTensor, FloatTensor]:
+        image_count = encoder_state.shape[0] // 2
+        token_index_batched = token_index[[0] * image_count * 2]
+        prev_tokens = prev_tokens[list(range(image_count)) * 2]
+        decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
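
The list(range(image_count)) * 2 index replaces torch.cat for duplicating the per-image token rows across the two halves of the batch (unconditional first, conditional second, matching the expanded_indices layout built in forward below). A standalone sketch of the indexing, with arbitrary values:

    import torch

    image_count = 3
    prev_tokens = torch.tensor([10, 11, 12])   # one sampled token per image

    # repeat the block once: unconditional half, then conditional half
    print(prev_tokens[list(range(image_count)) * 2])   # tensor([10, 11, 12, 10, 11, 12])

    token_index = torch.tensor([5])
    print(token_index[[0] * image_count * 2])          # tensor([5, 5, 5, 5, 5, 5])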
@@ -182,38 +174,52 @@ class DalleBartDecoder(nn.Module):
         decoder_state = self.final_ln(decoder_state)
         logits = self.lm_head(decoder_state)
         a = self.condition_factor
-        logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
+        logits: FloatTensor = (
+            logits[:image_count, -1] * (1 - a) +
+            logits[image_count:, -1] * a
+        )
         top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
-            logits < top_logits[-1],
+            logits < top_logits[:, [-1]],
             self.zero_prob,
-            torch.exp(logits - top_logits[0])
+            torch.exp(logits - top_logits[:, [0]])
         )
         return probs, torch.stack(attention_states_new)
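
Two things happen in this hunk: the super-conditioning mix (1 - a) * uncond + a * cond is now computed row-wise for every image at once, and the top-50 filter keeps a per-row threshold by indexing with [:, [-1]], which preserves a column axis for broadcasting. A small self-contained sketch of the same batched top-k masking (toy sizes, top-3 instead of top-50):

    import torch

    image_count, vocab_size, a = 2, 8, 10.0
    logits = torch.randn(2 * image_count, 1, vocab_size)  # uncond rows, then cond rows

    mixed = logits[:image_count, -1] * (1 - a) + logits[image_count:, -1] * a

    top, _ = mixed.topk(3, dim=-1)
    probs = torch.where(
        mixed < top[:, [-1]],            # per-row cutoff, shape (image_count, 1)
        torch.zeros([1]),
        torch.exp(mixed - top[:, [0]])   # stabilize against each row's max
    )
    print(torch.multinomial(probs, 1)[:, 0])  # one token id per image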
     def forward(
         self,
+        image_count: int,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
     ) -> LongTensor:
-        image_tokens: List[LongTensor] = []
-        attention_state = torch.zeros(self.attention_state_shape)
-        if torch.cuda.is_available():
-            attention_state = attention_state.cuda()
-        image_token = self.start_token
+        expanded_indices = [0] * image_count + [1] * image_count
+        text_tokens = text_tokens[expanded_indices]
+        encoder_state = encoder_state[expanded_indices]
+        attention_mask = text_tokens.not_equal(1)
+        attention_state_shape = (
+            self.layer_count,
+            image_count * 4,
+            self.image_token_count,
+            self.embed_count
+        )
+        attention_state = torch.zeros(attention_state_shape)
+        if torch.cuda.is_available(): attention_state = attention_state.cuda()
+        image_tokens = self.start_token[[0] * image_count]
+        image_tokens_sequence: list[LongTensor] = []
         for i in range(self.sample_token_count):
             probs, attention_state = self.decode_step(
-                text_tokens = text_tokens,
+                attention_mask = attention_mask,
                 encoder_state = encoder_state,
                 attention_state = attention_state,
-                prev_token = image_token,
+                prev_tokens = image_tokens,
                 token_index = self.token_indices[[i]]
             )
-            image_token = torch.multinomial(probs, 1)
-            image_tokens += [image_token]
-        return torch.cat(image_tokens)
+            image_tokens = torch.multinomial(probs, 1)[:, 0]
+            image_tokens_sequence += [image_tokens]
+        return torch.stack(image_tokens_sequence).T
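
In the rewritten loop each step samples one token per image, so the collected sequence is a list of sample_token_count tensors of shape (image_count,); stacking gives (sample_token_count, image_count) and the final .T hands back one row of 256 tokens per image. The attention cache's second dimension, image_count * 4, appears to cover the conditional and unconditional halves times keys and values, generalizing the removed 2 * batch_count shape (an inference from the deleted code, not stated in the commit). A shape-only sketch of the token collection:

    import torch

    image_count, sample_token_count = 4, 256
    sequence = [torch.zeros(image_count, dtype=torch.long)
                for _ in range(sample_token_count)]
    tokens = torch.stack(sequence).T
    print(tokens.shape)   # torch.Size([4, 256]): one image's tokens per row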

View File

@@ -1,4 +1,3 @@
-from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
@@ -121,7 +120,7 @@ class DalleBartEncoder(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayer] = nn.ModuleList([
+        self.layers: list[EncoderLayer] = nn.ModuleList([
             EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
@@ -137,8 +136,7 @@ class DalleBartEncoder(nn.Module):
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
-        batch_count = text_tokens.shape[0]
-        pose_tokens = torch.stack([self.token_indices] * batch_count)
+        pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
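
The tensor[None][[0] * n] pattern used here (and in DecoderLayer above) is an indexing spelling of torch.stack([tensor] * n): add a leading axis, then gather row 0 n times. A quick equivalence check:

    import torch

    token_indices = torch.arange(5)
    batch_count = 3

    stacked = torch.stack([token_indices] * batch_count)
    indexed = token_indices[None][[0] * batch_count]
    print(torch.equal(stacked, indexed))   # True; both are shape (3, 5)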

View File

@@ -3,8 +3,6 @@ from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 torch.set_grad_enabled(False)
 
-BATCH_COUNT: int = 1
-
 class ResnetBlock(Module):
     def __init__(self, log2_count_in: int, log2_count_out: int):
@@ -42,22 +40,22 @@ class AttentionBlock(Module):
         self.proj_out = Conv2d(n, n, 1)
 
     def forward(self, x: Tensor) -> Tensor:
-        n = 2 ** 9
+        n, m = 2 ** 9, x.shape[0]
         h = x
         h = self.norm(h)
         q = self.q.forward(h)
         k = self.k.forward(h)
         v = self.v.forward(h)
-        q = q.reshape(BATCH_COUNT, n, 2 ** 8)
+        q = q.reshape(m, n, 2 ** 8)
         q = q.permute(0, 2, 1)
-        k = k.reshape(BATCH_COUNT, n, 2 ** 8)
+        k = k.reshape(m, n, 2 ** 8)
         w = torch.bmm(q, k)
         w /= n ** 0.5
         w = torch.softmax(w, dim=2)
-        v = v.reshape(BATCH_COUNT, n, 2 ** 8)
+        v = v.reshape(m, n, 2 ** 8)
         w = w.permute(0, 2, 1)
         h = torch.bmm(v, w)
-        h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
+        h = h.reshape(m, n, 2 ** 4, 2 ** 4)
         h = self.proj_out.forward(h)
         return x + h
@@ -169,10 +167,10 @@ class VQGanDetokenizer(Module):
     def forward(self, z: Tensor) -> Tensor:
         z = self.embedding.forward(z)
-        z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
+        z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
         z = self.post_quant_conv.forward(z)
         z = self.decoder.forward(z)
         z = z.permute(0, 2, 3, 1)
         z = z.clip(0.0, 1.0) * 255
-        return z[0]
+        return z
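
With BATCH_COUNT gone, the detokenizer is batch-size agnostic: z.view((z.shape[0], ...)) keeps whatever batch arrives, and returning z instead of z[0] hands all image_count images back to MinDalle for grid assembly. A shape-only sketch of the entry bookkeeping, using a stand-in codebook (sizes are illustrative, not the model's real vocabulary):

    import torch

    image_count = 4
    codebook = torch.nn.Embedding(2 ** 14, 2 ** 8)             # stand-in embedding
    tokens = torch.randint(0, 2 ** 14, (image_count, 2 ** 8))  # 256 tokens per image

    z = codebook(tokens)                                       # (4, 256, 256)
    z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))           # (4, 16, 16, 256) grid
    z = z.permute(0, 3, 1, 2).contiguous()                     # (4, 256, 16, 16) for convs
    print(z.shape)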

View File

@ -1,15 +1,13 @@
from math import inf from math import inf
from typing import List, Tuple
class TextTokenizer: class TextTokenizer:
def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True): def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True):
self.is_verbose = is_verbose self.is_verbose = is_verbose
self.token_from_subword = vocab self.token_from_subword = vocab
pairs = [tuple(pair.split()) for pair in merges] pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def tokenize(self, text: str) -> List[int]: def tokenize(self, text: str) -> list[int]:
sep_token = self.token_from_subword['</s>'] sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>'] cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>'] unk_token = self.token_from_subword['<unk>']
@@ -21,8 +19,8 @@ class TextTokenizer:
         ]
         return [cls_token] + tokens + [sep_token]
 
-    def get_byte_pair_encoding(self, word: str) -> List[str]:
-        def get_pair_rank(pair: Tuple[str, str]) -> int:
+    def get_byte_pair_encoding(self, word: str) -> list[str]:
+        def get_pair_rank(pair: tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
         subwords = [chr(ord(" ") + 256)] + list(word)