optionally tile images in token space

This commit is contained in:
Brett Kuprel 2022-07-17 07:33:31 -04:00
parent 39376c9cf2
commit 798e6ac5a3
6 changed files with 145 additions and 107 deletions

4
README.md vendored
View File

@ -62,7 +62,7 @@ display(image)
Credit to [@hardmaru](https://twitter.com/hardmaru) for the [example](https://twitter.com/hardmaru/status/1544354119527596034) Credit to [@hardmaru](https://twitter.com/hardmaru) for the [example](https://twitter.com/hardmaru/status/1544354119527596034)
### Saving Individual Images <!-- ### Saving Individual Images
The images can also be generated as a `FloatTensor` in case you want to process them manually. The images can also be generated as a `FloatTensor` in case you want to process them manually.
```python ```python
@ -85,7 +85,7 @@ Then image $i$ can be converted to a PIL.Image and saved
```python ```python
image = Image.fromarray(images[i]) image = Image.fromarray(images[i])
image.save('image_{}.png'.format(i)) image.save('image_{}.png'.format(i))
``` ``` -->
### Progressive Outputs ### Progressive Outputs

2
cog.yaml vendored
View File

@ -6,7 +6,7 @@ build:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.3.13" - "min-dalle==0.3.15"
run: run:
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -156,39 +156,34 @@ class MinDalle:
self.detokenizer = self.detokenizer.to(device=self.device) self.detokenizer = self.detokenizer.to(device=self.device)
def images_from_tokens( def image_grid_from_tokens(
self, self,
image_tokens: LongTensor, image_tokens: LongTensor,
is_seamless: bool,
is_verbose: bool = False is_verbose: bool = False
) -> FloatTensor: ) -> FloatTensor:
if not self.is_reusable: del self.decoder if not self.is_reusable: del self.decoder
torch.cuda.empty_cache() torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer() if not self.is_reusable: self.init_detokenizer()
if is_verbose: print("detokenizing image") if is_verbose: print("detokenizing image")
images = self.detokenizer.forward(image_tokens).to(torch.uint8) images = self.detokenizer.forward(is_seamless, image_tokens)
if not self.is_reusable: del self.detokenizer if not self.is_reusable: del self.detokenizer
return images return images
def grid_from_images(self, images: FloatTensor) -> Image.Image: def generate_image_stream(
grid_size = int(sqrt(images.shape[0]))
images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
image = Image.fromarray(image.to('cpu').numpy())
return image
def generate_images_stream(
self, self,
text: str, text: str,
seed: int, seed: int,
image_count: int, grid_size: int,
progressive_outputs: bool = False, progressive_outputs: bool = False,
is_seamless: bool = False,
temperature: float = 1, temperature: float = 1,
top_k: int = 256, top_k: int = 256,
supercondition_factor: int = 16, supercondition_factor: int = 16,
is_verbose: bool = False is_verbose: bool = False
) -> Iterator[FloatTensor]: ) -> Iterator[Image.Image]:
image_count = grid_size ** 2
if is_verbose: print("tokenizing text") if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
if len(tokens) > self.text_token_count: if len(tokens) > self.text_token_count:
@ -254,58 +249,13 @@ class MinDalle:
with torch.cuda.amp.autocast(dtype=torch.float32): with torch.cuda.amp.autocast(dtype=torch.float32):
if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256: if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
yield self.images_from_tokens( image = self.image_grid_from_tokens(
image_tokens=image_tokens[1:].T, image_tokens=image_tokens[1:].T,
is_seamless=is_seamless,
is_verbose=is_verbose is_verbose=is_verbose
) )
image = image.to(torch.uint8).to('cpu').numpy()
yield Image.fromarray(image)
def generate_image_stream(
self,
text: str,
seed: int,
grid_size: int,
progressive_outputs: bool = False,
temperature: float = 1,
top_k: int = 256,
supercondition_factor: int = 16,
is_verbose: bool = False
) -> Iterator[Image.Image]:
images_stream = self.generate_images_stream(
text=text,
seed=seed,
image_count=grid_size ** 2,
progressive_outputs=progressive_outputs,
temperature=temperature,
top_k=top_k,
supercondition_factor=supercondition_factor,
is_verbose=is_verbose
)
for images in images_stream:
yield self.grid_from_images(images)
def generate_images(
self,
text: str,
seed: int = -1,
image_count: int = 1,
temperature: float = 1,
top_k: int = 1024,
supercondition_factor: int = 16,
is_verbose: bool = False
) -> FloatTensor:
images_stream = self.generate_images_stream(
text=text,
seed=seed,
image_count=image_count,
temperature=temperature,
progressive_outputs=False,
top_k=top_k,
supercondition_factor=supercondition_factor,
is_verbose=is_verbose
)
return next(images_stream)
def generate_image( def generate_image(
@ -329,3 +279,54 @@ class MinDalle:
is_verbose=is_verbose is_verbose=is_verbose
) )
return next(image_stream) return next(image_stream)
# def images_from_image(image: Image.Image) -> FloatTensor:
# pass
# def generate_images_stream(
# self,
# text: str,
# seed: int,
# grid_size: int,
# progressive_outputs: bool = False,
# temperature: float = 1,
# top_k: int = 256,
# supercondition_factor: int = 16,
# is_verbose: bool = False
# ) -> Iterator[FloatTensor]:
# image_stream = self.generate_image_stream(
# text=text,
# seed=seed,
# image_count=grid_size ** 2,
# progressive_outputs=progressive_outputs,
# is_seamless=False,
# temperature=temperature,
# top_k=top_k,
# supercondition_factor=supercondition_factor,
# is_verbose=is_verbose
# )
# for image in image_stream:
# yield self.images_from_image(image)
# def generate_images(
# self,
# text: str,
# seed: int = -1,
# image_count: int = 1,
# temperature: float = 1,
# top_k: int = 1024,
# supercondition_factor: int = 16,
# is_verbose: bool = False
# ) -> FloatTensor:
# images_stream = self.generate_images_stream(
# text=text,
# seed=seed,
# image_count=image_count,
# temperature=temperature,
# progressive_outputs=False,
# top_k=top_k,
# supercondition_factor=supercondition_factor,
# is_verbose=is_verbose
# )
# return next(images_stream)

View File

@ -1,19 +1,20 @@
import torch import torch
from torch import nn
from torch import FloatTensor, LongTensor from torch import FloatTensor, LongTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding from math import sqrt
class ResnetBlock(Module): class ResnetBlock(nn.Module):
def __init__(self, log2_count_in: int, log2_count_out: int): def __init__(self, log2_count_in: int, log2_count_out: int):
super().__init__() super().__init__()
m, n = 2 ** log2_count_in, 2 ** log2_count_out m, n = 2 ** log2_count_in, 2 ** log2_count_out
self.is_middle = m == n self.is_middle = m == n
self.norm1 = GroupNorm(2 ** 5, m) self.norm1 = nn.GroupNorm(2 ** 5, m)
self.conv1 = Conv2d(m, n, 3, padding=1) self.conv1 = nn.Conv2d(m, n, 3, padding=1)
self.norm2 = GroupNorm(2 ** 5, n) self.norm2 = nn.GroupNorm(2 ** 5, n)
self.conv2 = Conv2d(n, n, 3, padding=1) self.conv2 = nn.Conv2d(n, n, 3, padding=1)
if not self.is_middle: if not self.is_middle:
self.nin_shortcut = Conv2d(m, n, 1) self.nin_shortcut = nn.Conv2d(m, n, 1)
def forward(self, x: FloatTensor) -> FloatTensor: def forward(self, x: FloatTensor) -> FloatTensor:
h = x h = x
@ -28,38 +29,39 @@ class ResnetBlock(Module):
return x + h return x + h
class AttentionBlock(Module): class AttentionBlock(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
n = 2 ** 9 n = 2 ** 9
self.norm = GroupNorm(2 ** 5, n) self.norm = nn.GroupNorm(2 ** 5, n)
self.q = Conv2d(n, n, 1) self.q = nn.Conv2d(n, n, 1)
self.k = Conv2d(n, n, 1) self.k = nn.Conv2d(n, n, 1)
self.v = Conv2d(n, n, 1) self.v = nn.Conv2d(n, n, 1)
self.proj_out = Conv2d(n, n, 1) self.proj_out = nn.Conv2d(n, n, 1)
def forward(self, x: FloatTensor) -> FloatTensor: def forward(self, x: FloatTensor) -> FloatTensor:
n, m = 2 ** 9, x.shape[0] n, m = 2 ** 9, x.shape[0]
h = x h = x
h = self.norm(h) h = self.norm(h)
q = self.q.forward(h)
k = self.k.forward(h) k = self.k.forward(h)
v = self.v.forward(h) v = self.v.forward(h)
q = q.reshape(m, n, 2 ** 8) q = self.q.forward(h)
k = k.reshape(m, n, -1)
v = v.reshape(m, n, -1)
q = q.reshape(m, n, -1)
q = q.permute(0, 2, 1) q = q.permute(0, 2, 1)
k = k.reshape(m, n, 2 ** 8)
w = torch.bmm(q, k) w = torch.bmm(q, k)
w /= n ** 0.5 w /= n ** 0.5
w = torch.softmax(w, dim=2) w = torch.softmax(w, dim=2)
v = v.reshape(m, n, 2 ** 8)
w = w.permute(0, 2, 1) w = w.permute(0, 2, 1)
h = torch.bmm(v, w) h = torch.bmm(v, w)
h = h.reshape(m, n, 2 ** 4, 2 ** 4) token_count = int(sqrt(h.shape[-1]))
h = h.reshape(m, n, token_count, token_count)
h = self.proj_out.forward(h) h = self.proj_out.forward(h)
return x + h return x + h
class MiddleLayer(Module): class MiddleLayer(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.block_1 = ResnetBlock(9, 9) self.block_1 = ResnetBlock(9, 9)
@ -73,12 +75,12 @@ class MiddleLayer(Module):
return h return h
class Upsample(Module): class Upsample(nn.Module):
def __init__(self, log2_count): def __init__(self, log2_count):
super().__init__() super().__init__()
n = 2 ** log2_count n = 2 ** log2_count
self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
self.conv = Conv2d(n, n, 3, padding=1) self.conv = nn.Conv2d(n, n, 3, padding=1)
def forward(self, x: FloatTensor) -> FloatTensor: def forward(self, x: FloatTensor) -> FloatTensor:
x = self.upsample.forward(x.to(torch.float32)) x = self.upsample.forward(x.to(torch.float32))
@ -86,7 +88,7 @@ class Upsample(Module):
return x return x
class UpsampleBlock(Module): class UpsampleBlock(nn.Module):
def __init__( def __init__(
self, self,
log2_count_in: int, log2_count_in: int,
@ -97,19 +99,19 @@ class UpsampleBlock(Module):
super().__init__() super().__init__()
self.has_attention = has_attention self.has_attention = has_attention
self.has_upsample = has_upsample self.has_upsample = has_upsample
self.block = ModuleList([
self.block = nn.ModuleList([
ResnetBlock(log2_count_in, log2_count_out), ResnetBlock(log2_count_in, log2_count_out),
ResnetBlock(log2_count_out, log2_count_out), ResnetBlock(log2_count_out, log2_count_out),
ResnetBlock(log2_count_out, log2_count_out) ResnetBlock(log2_count_out, log2_count_out)
]) ])
if has_attention: if has_attention:
self.attn = ModuleList([ self.attn = nn.ModuleList([
AttentionBlock(), AttentionBlock(),
AttentionBlock(), AttentionBlock(),
AttentionBlock() AttentionBlock()
]) ])
else:
self.attn = ModuleList()
if has_upsample: if has_upsample:
self.upsample = Upsample(log2_count_out) self.upsample = Upsample(log2_count_out)
@ -125,14 +127,14 @@ class UpsampleBlock(Module):
return h return h
class Decoder(Module): class Decoder(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv_in = Conv2d(2 ** 8, 2 ** 9, 3, padding=1) self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
self.mid = MiddleLayer() self.mid = MiddleLayer()
self.up = ModuleList([ self.up = nn.ModuleList([
UpsampleBlock(7, 7, False, False), UpsampleBlock(7, 7, False, False),
UpsampleBlock(8, 7, False, True), UpsampleBlock(8, 7, False, True),
UpsampleBlock(8, 8, False, True), UpsampleBlock(8, 8, False, True),
@ -140,8 +142,8 @@ class Decoder(Module):
UpsampleBlock(9, 9, True, True) UpsampleBlock(9, 9, True, True)
]) ])
self.norm_out = GroupNorm(2 ** 5, 2 ** 7) self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7)
self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1) self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1)
def forward(self, z: FloatTensor) -> FloatTensor: def forward(self, z: FloatTensor) -> FloatTensor:
z = self.conv_in.forward(z) z = self.conv_in.forward(z)
@ -156,22 +158,40 @@ class Decoder(Module):
return z return z
class VQGanDetokenizer(Module): class VQGanDetokenizer(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
vocab_count, embed_count = 2 ** 14, 2 ** 8 vocab_count, embed_count = 2 ** 14, 2 ** 8
self.vocab_count = vocab_count self.vocab_count = vocab_count
self.embedding = Embedding(vocab_count, embed_count) self.embedding = nn.Embedding(vocab_count, embed_count)
self.post_quant_conv = Conv2d(embed_count, embed_count, 1) self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1)
self.decoder = Decoder() self.decoder = Decoder()
def forward(self, z: LongTensor) -> FloatTensor: def forward(self, is_seamless: bool, z: LongTensor) -> FloatTensor:
z.clamp_(0, self.vocab_count - 1) z.clamp_(0, self.vocab_count - 1)
z = self.embedding.forward(z) grid_size = int(sqrt(z.shape[0]))
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) token_count = grid_size * 2 ** 4
if is_seamless:
z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4])
z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
z = z.flatten().unsqueeze(1)
z = self.embedding.forward(z)
z = z.view((1, token_count, token_count, 2 ** 8))
else:
z = self.embedding.forward(z)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous() z = z.permute(0, 3, 1, 2).contiguous()
z = self.post_quant_conv.forward(z) z = self.post_quant_conv.forward(z)
z = self.decoder.forward(z) z = self.decoder.forward(z)
z = z.permute(0, 2, 3, 1) z = z.permute(0, 2, 3, 1)
z = z.clip(0.0, 1.0) * 255 z = z.clip(0.0, 1.0) * 255
if is_seamless:
z = z[0]
else:
z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3])
z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
return z return z

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.3.13', version='0.3.15',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',

View File

@ -57,6 +57,7 @@ sv_prompt = tkinter.StringVar(value="artificial intelligence")
sv_temperature = tkinter.StringVar(value="1") sv_temperature = tkinter.StringVar(value="1")
sv_topk = tkinter.StringVar(value="128") sv_topk = tkinter.StringVar(value="128")
sv_supercond = tkinter.StringVar(value="16") sv_supercond = tkinter.StringVar(value="16")
bv_seamless = tkinter.BooleanVar(value=False)
def generate(): def generate():
# check fields # check fields
@ -75,6 +76,10 @@ def generate():
except: except:
sv_supercond.set("ERROR") sv_supercond.set("ERROR")
return return
try:
is_seamless = bool(bv_seamless.get())
except:
return
# and continue # and continue
global label_image_content global label_image_content
image_stream = model.generate_image_stream( image_stream = model.generate_image_stream(
@ -82,16 +87,22 @@ def generate():
grid_size=2, grid_size=2,
seed=-1, seed=-1,
progressive_outputs=True, progressive_outputs=True,
is_seamless=is_seamless,
temperature=temperature, temperature=temperature,
top_k=topk, top_k=topk,
supercondition_factor=supercond, supercondition_factor=supercond,
is_verbose=True is_verbose=True
) )
for image in image_stream: for image in image_stream:
global final_image
final_image = image
label_image_content = PIL.ImageTk.PhotoImage(image) label_image_content = PIL.ImageTk.PhotoImage(image)
label_image.configure(image=label_image_content) label_image.configure(image=label_image_content)
label_image.update() label_image.update()
def save():
final_image.save('out.png')
frm = ttk.Frame(root, padding=16) frm = ttk.Frame(root, padding=16)
frm.grid() frm.grid()
@ -124,8 +135,14 @@ ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6)
ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6) ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6)
# #
ttk.Label(props, image=padding_image).grid(column=0, row=7) ttk.Label(props, image=padding_image).grid(column=0, row=7)
# seamless
ttk.Label(props, text="Seamless:").grid(column=0, row=8)
ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8)
#
ttk.Label(props, image=padding_image).grid(column=0, row=9)
# buttons # buttons
ttk.Button(props, text="Generate", command=generate).grid(column=0, row=8) ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10)
ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=8) ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=10)
ttk.Button(props, text="Save", command=save).grid(column=2, row=10)
root.mainloop() root.mainloop()