generate_images_stream and generate_images

This commit is contained in:
Brett Kuprel 2022-07-07 17:03:47 -04:00
parent b17bea11b6
commit 2cac9220b5
5 changed files with 62 additions and 111 deletions

102
README.rst vendored
View File

@ -1,102 +0,0 @@
min(DALL·E)
===========
|Colab|   |Replicate|   |Discord|
This is a fast, minimal port of Boris Daymas `DALL·E
Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped
down for inference and converted to PyTorch. The only third party
dependencies are numpy, requests, pillow and torch.
To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
Replicate
The flax model and code for converting it to torch can be found
`here <https://github.com/kuprel/min-dalle-flax>`__.
Install
-------
.. code:: bash
$ pip install min-dalle
Usage
-----
Load the model parameters once and reuse the model to generate multiple
images.
.. code:: python
from min_dalle import MinDalle
model = MinDalle(
is_mega=True,
is_reusable=True,
models_root='./pretrained'
)
The required models will be downloaded to ``models_root`` if they are
not already there. Once everything has finished initializing, call
``generate_image`` with some text as many times as you want. Use a
positive ``seed`` for reproducible results. Higher values for
``log2_supercondition_factor`` result in better agreement with the text
but a narrower variety of generated images. Every image token is sampled
from the top-:math:`k` most probable tokens.
.. code:: python
image = model.generate_image(
text='Nuclear explosion broccoli',
seed=-1,
grid_size=4,
log2_k=6,
log2_supercondition_factor=5,
is_verbose=False
)
display(image)
Interactive
~~~~~~~~~~~
If the model is being used interactively (e.g. in a notebook)
``generate_image_stream`` can be used to generate a stream of images as
the model is decoding. The detokenizer adds a slight delay for each
image. Setting ``log2_mid_count`` to 3 results in a total of
``2 ** 3 = 8`` generated images. The only valid values for
``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the
colab.
.. code:: python
image_stream = model.generate_image_stream(
text='Dali painting of WALL·E',
seed=-1,
grid_size=3,
log2_mid_count=3,
log2_k=6,
log2_supercondition_factor=3,
is_verbose=False
)
for image in image_stream:
display(image)
Command Line
~~~~~~~~~~~~
Use ``image_from_text.py`` to generate images from the command line.
.. code:: bash
$ python image_from_text.py --text='artificial intelligence' --no-mega
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb
.. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge
:target: https://replicate.com/kuprel/min-dalle
.. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.com/channels/823813159592001537/912729332311556136

2
cog.yaml vendored
View File

@ -6,7 +6,7 @@ build:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.2.36" - "min-dalle==0.3.1"
run: run:
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -1,7 +1,9 @@
import os import os
from PIL import Image from PIL import Image
from matplotlib.pyplot import grid
import numpy import numpy
from torch import LongTensor from torch import LongTensor
from math import sqrt
import torch import torch
import json import json
import requests import requests
@ -142,25 +144,29 @@ class MinDalle:
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
def image_from_tokens( def images_from_tokens(
self, self,
grid_size: int,
image_tokens: LongTensor, image_tokens: LongTensor,
is_verbose: bool = False is_verbose: bool = False
) -> Image.Image: ) -> LongTensor:
if not self.is_reusable: del self.decoder if not self.is_reusable: del self.decoder
if torch.cuda.is_available(): torch.cuda.empty_cache() if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer() if not self.is_reusable: self.init_detokenizer()
if is_verbose: print("detokenizing image") if is_verbose: print("detokenizing image")
images = self.detokenizer.forward(image_tokens).to(torch.uint8) images = self.detokenizer.forward(image_tokens).to(torch.uint8)
if not self.is_reusable: del self.detokenizer if not self.is_reusable: del self.detokenizer
return images
def grid_from_images(self, images: LongTensor) -> Image.Image:
grid_size = int(sqrt(images.shape[0]))
images = images.reshape([grid_size] * 2 + list(images.shape[1:])) images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2) image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
image = Image.fromarray(image.to('cpu').detach().numpy()) image = Image.fromarray(image.to('cpu').detach().numpy())
return image return image
def generate_image_stream( def generate_images_stream(
self, self,
text: str, text: str,
seed: int, seed: int,
@ -169,7 +175,7 @@ class MinDalle:
log2_k: int = 6, log2_k: int = 6,
log2_supercondition_factor: int = 3, log2_supercondition_factor: int = 3,
is_verbose: bool = False is_verbose: bool = False
) -> Iterator[Image.Image]: ) -> Iterator[LongTensor]:
assert(log2_mid_count in range(5)) assert(log2_mid_count in range(5))
if is_verbose: print("tokenizing text") if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
@ -219,8 +225,53 @@ class MinDalle:
with torch.cuda.amp.autocast(dtype=torch.float32): with torch.cuda.amp.autocast(dtype=torch.float32):
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
tokens = image_tokens[:, 1:] tokens = image_tokens[:, 1:]
image = self.image_from_tokens(grid_size, tokens, is_verbose) images = self.images_from_tokens(tokens, is_verbose)
yield image yield images
def generate_image_stream(
self,
text: str,
seed: int,
grid_size: int,
log2_mid_count: int,
log2_k: int = 6,
log2_supercondition_factor: int = 3,
is_verbose: bool = False
) -> Iterator[Image.Image]:
images_stream = self.generate_images_stream(
text,
seed,
grid_size,
log2_mid_count,
log2_k,
log2_supercondition_factor,
is_verbose
)
for images in images_stream:
yield self.grid_from_images(images)
def generate_images(
self,
text: str,
seed: int = -1,
grid_size: int = 1,
log2_k: int = 6,
log2_supercondition_factor: int = 3,
is_verbose: bool = False
) -> LongTensor:
log2_mid_count = 0
images_stream = self.generate_images_stream(
text,
seed,
grid_size,
log2_mid_count,
log2_k,
log2_supercondition_factor,
is_verbose
)
return next(images_stream)
def generate_image( def generate_image(

View File

@ -1,5 +1,6 @@
from min_dalle import MinDalle from min_dalle import MinDalle
import tempfile import tempfile
import torch
from typing import Iterator from typing import Iterator
from cog import BasePredictor, Path, Input from cog import BasePredictor, Path, Input
@ -53,5 +54,6 @@ class ReplicatePredictor(BasePredictor):
except: except:
print("An error occured, deleting model") print("An error occured, deleting model")
del self.model del self.model
torch.cuda.empty_cache()
self.setup() self.setup()
raise Exception("There was an error, please try again") raise Exception("There was an error, please try again")

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.36', version='0.3.1',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',