|
|
|
@ -1,7 +1,9 @@ |
|
|
|
|
import os |
|
|
|
|
from PIL import Image |
|
|
|
|
from matplotlib.pyplot import grid |
|
|
|
|
import numpy |
|
|
|
|
from torch import LongTensor |
|
|
|
|
from math import sqrt |
|
|
|
|
import torch |
|
|
|
|
import json |
|
|
|
|
import requests |
|
|
|
@ -142,25 +144,29 @@ class MinDalle: |
|
|
|
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def images_from_tokens(
    self,
    image_tokens: LongTensor,
    is_verbose: bool = False
) -> LongTensor:
    """Detokenize image tokens into a batch of uint8 images.

    When ``self.is_reusable`` is false the method manages model lifetime
    itself: ``self.decoder`` is deleted before detokenizing (presumably to
    release its memory), the detokenizer is created via
    ``self.init_detokenizer()``, and ``self.detokenizer`` is deleted again
    once the images have been produced. When ``self.is_reusable`` is true
    the existing ``self.detokenizer`` is used and left in place.

    Args:
        image_tokens: Tokens handed unchanged to
            ``self.detokenizer.forward``.
        is_verbose: If true, print a progress message before detokenizing.

    Returns:
        The detokenizer output cast to ``torch.uint8``.
    """
    if not self.is_reusable:
        del self.decoder
    # Clear the CUDA cache after the decoder may have been dropped and
    # before the detokenizer is (re)built.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if not self.is_reusable:
        self.init_detokenizer()
    if is_verbose:
        print("detokenizing image")
    images = self.detokenizer.forward(image_tokens).to(torch.uint8)
    if not self.is_reusable:
        del self.detokenizer
    return images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def grid_from_images(self, images: LongTensor) -> Image.Image:
    """Tile a batch of images into one square grid image.

    The batch dimension is split into ``grid_size x grid_size`` where
    ``grid_size = int(sqrt(images.shape[0]))``, then the grid axes are
    merged into the two axes that follow the batch axis.

    Args:
        images: Batch of images, batch axis first. Assumed to be laid out
            as (count, height, width, channels) -- TODO confirm against
            the detokenizer output.

    Returns:
        A PIL image built from the tiled tensor.

    Raises:
        ValueError: If the number of images is not a perfect square, since
            the batch cannot then be arranged into a square grid.
    """
    image_count = images.shape[0]
    grid_size = int(sqrt(image_count))
    if grid_size * grid_size != image_count:
        raise ValueError(
            f"expected a square number of images, got {image_count}"
        )
    # [count, ...] -> [grid, grid, ...]
    images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
    # Merge axes 1 and 2, swap the first two axes, then merge axes 1 and 2
    # again so that each grid axis is folded into one image axis.
    image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
    image = Image.fromarray(image.to('cpu').detach().numpy())
    return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_stream( |
|
|
|
|
def generate_images_stream( |
|
|
|
|
self, |
|
|
|
|
text: str, |
|
|
|
|
seed: int, |
|
|
|
@ -169,7 +175,7 @@ class MinDalle: |
|
|
|
|
log2_k: int = 6, |
|
|
|
|
log2_supercondition_factor: int = 3, |
|
|
|
|
is_verbose: bool = False |
|
|
|
|
) -> Iterator[Image.Image]: |
|
|
|
|
) -> Iterator[LongTensor]: |
|
|
|
|
assert(log2_mid_count in range(5)) |
|
|
|
|
if is_verbose: print("tokenizing text") |
|
|
|
|
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) |
|
|
|
@ -219,8 +225,53 @@ class MinDalle: |
|
|
|
|
with torch.cuda.amp.autocast(dtype=torch.float32): |
|
|
|
|
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: |
|
|
|
|
tokens = image_tokens[:, 1:] |
|
|
|
|
image = self.image_from_tokens(grid_size, tokens, is_verbose) |
|
|
|
|
yield image |
|
|
|
|
images = self.images_from_tokens(tokens, is_verbose) |
|
|
|
|
yield images |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_stream(
    self,
    text: str,
    seed: int,
    grid_size: int,
    log2_mid_count: int,
    log2_k: int = 6,
    log2_supercondition_factor: int = 3,
    is_verbose: bool = False
) -> Iterator[Image.Image]:
    """Yield a grid image for every batch produced by the images stream.

    All arguments are forwarded positionally, in order, to
    ``self.generate_images_stream``; each item it yields is converted with
    ``self.grid_from_images`` before being yielded to the caller.

    Args:
        text: Prompt text forwarded to the stream.
        seed: Seed value forwarded to the stream.
        grid_size: Grid size forwarded to the stream.
        log2_mid_count: Forwarded to the stream; presumably controls how
            many intermediate results are yielded -- TODO confirm.
        log2_k: Forwarded to the stream.
        log2_supercondition_factor: Forwarded to the stream.
        is_verbose: Forwarded to the stream.

    Yields:
        One PIL grid image per item of the underlying stream.
    """
    images_stream = self.generate_images_stream(
        text,
        seed,
        grid_size,
        log2_mid_count,
        log2_k,
        log2_supercondition_factor,
        is_verbose
    )
    for images in images_stream:
        yield self.grid_from_images(images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_images(
    self,
    text: str,
    seed: int = -1,
    grid_size: int = 1,
    log2_k: int = 6,
    log2_supercondition_factor: int = 3,
    is_verbose: bool = False
) -> LongTensor:
    """Return the first batch of images yielded by the images stream.

    Calls ``self.generate_images_stream`` with ``log2_mid_count`` fixed at
    0 and returns the first item it yields. With that setting the stream
    presumably yields only the finished batch -- TODO confirm.

    Args:
        text: Prompt text forwarded to the stream.
        seed: Seed value forwarded to the stream.
        grid_size: Grid size forwarded to the stream.
        log2_k: Forwarded to the stream.
        log2_supercondition_factor: Forwarded to the stream.
        is_verbose: Forwarded to the stream.

    Returns:
        The first item yielded by ``self.generate_images_stream``.
    """
    log2_mid_count = 0
    images_stream = self.generate_images_stream(
        text,
        seed,
        grid_size,
        log2_mid_count,
        log2_k,
        log2_supercondition_factor,
        is_verbose
    )
    return next(images_stream)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image( |
|
|
|
|