@ -1,10 +1,11 @@
import os
from PIL import Image
import numpy
from torch import LongTensor
from torch import LongTensor , FloatTensor
import torch
import json
import requests
from typing import Callable , Tuple
torch . set_grad_enabled ( False )
torch . set_num_threads ( os . cpu_count ( ) )
@ -26,7 +27,6 @@ class MinDalle:
self . is_reusable = is_reusable
self . is_verbose = is_verbose
self . text_token_count = 64
self . image_token_count = 256
self . layer_count = 24 if is_mega else 12
self . attention_head_count = 32 if is_mega else 16
self . embed_count = 2048 if is_mega else 1024
@ -91,7 +91,7 @@ class MinDalle:
vocab = json . load ( f )
with open ( self . merges_path , ' r ' , encoding = ' utf8 ' ) as f :
merges = f . read ( ) . split ( " \n " ) [ 1 : - 1 ]
self . tokenizer = TextTokenizer ( vocab , merges , is_verbose = self . is_verbose )
self . tokenizer = TextTokenizer ( vocab , merges )
def init_encoder ( self ) :
@ -117,7 +117,6 @@ class MinDalle:
if not is_downloaded : self . download_decoder ( )
if self . is_verbose : print ( " initializing DalleBartDecoder " )
self . decoder = DalleBartDecoder (
image_token_count = self . image_token_count ,
image_vocab_count = self . image_vocab_count ,
attention_head_count = self . attention_head_count ,
embed_count = self . embed_count ,
@ -142,16 +141,37 @@ class MinDalle:
if torch . cuda . is_available ( ) : self . detokenizer = self . detokenizer . cuda ( )
def image_from_tokens (
self ,
grid_size : int ,
image_tokens : LongTensor ,
is_verbose : bool = False
) - > Image . Image :
if not self . is_reusable : del self . decoder
if torch . cuda . is_available ( ) : torch . cuda . empty_cache ( )
if not self . is_reusable : self . init_detokenizer ( )
if is_verbose : print ( " detokenizing image " )
images = self . detokenizer . forward ( image_tokens ) . to ( torch . uint8 )
if not self . is_reusable : del self . detokenizer
images = images . reshape ( [ grid_size ] * 2 + list ( images . shape [ 1 : ] ) )
image = images . flatten ( 1 , 2 ) . transpose ( 0 , 1 ) . flatten ( 1 , 2 )
image = Image . fromarray ( image . to ( ' cpu ' ) . detach ( ) . numpy ( ) )
return image
def generate_image_tokens (
self ,
text : str ,
seed : int ,
image_count : int ,
row_count : int
grid_size : int ,
row_count : int ,
mid_count : int = None ,
handle_intermediate_image : Callable [ [ int , Image . Image ] , None ] = None ,
is_verbose : bool = False
) - > LongTensor :
if self . is_verbose : print ( " tokenizing text " )
tokens = self . tokenizer . tokenize ( text )
if self . is_verbose : print ( " text tokens " , tokens )
if is_verbose : print ( " tokenizing text " )
tokens = self . tokenizer . tokenize ( text , is_verbose = is_verbose )
if is_verbose : print ( " text tokens " , tokens )
text_tokens = numpy . ones ( ( 2 , 64 ) , dtype = numpy . int32 )
text_tokens [ 0 , : 2 ] = [ tokens [ 0 ] , tokens [ - 1 ] ]
text_tokens [ 1 , : len ( tokens ) ] = tokens
@ -160,40 +180,57 @@ class MinDalle:
if torch . cuda . is_available ( ) : text_tokens = text_tokens . cuda ( )
if not self . is_reusable : self . init_encoder ( )
if self . is_verbose : print ( " encoding text tokens " )
if is_verbose : print ( " encoding text tokens " )
encoder_state = self . encoder . forward ( text_tokens )
if not self . is_reusable : del self . encoder
if torch . cuda . is_available ( ) : torch . cuda . empty_cache ( )
if not self . is_reusable : self . init_decoder ( )
if self . is_verbose : print ( " sampling image tokens " )
if seed > 0 : torch . manual_seed ( seed )
image_tokens = self . decoder . forward (
image_count ,
row_count ,
text_tokens ,
encoder_state
encoder_state , attention_mask , attention_state , image_tokens = (
self . decoder . decode_initial (
seed ,
grid_size * * 2 ,
text_tokens ,
encoder_state
)
)
if not self . is_reusable : del self . decoder
return image_tokens
for row_index in range ( row_count ) :
if is_verbose :
print ( ' sampling row {} of {} ' . format ( row_index + 1 , row_count ) )
attention_state , image_tokens = self . decoder . decode_row (
row_index ,
encoder_state ,
attention_mask ,
attention_state ,
image_tokens
)
if mid_count is not None :
if ( ( row_index + 1 ) * mid_count ) % row_count == 0 :
tokens = image_tokens [ : , 1 : ]
image = self . image_from_tokens ( grid_size , tokens , is_verbose )
handle_intermediate_image ( row_index , image )
return image_tokens [ : , 1 : ]
def generate_image (
self ,
text : str ,
seed : int = - 1 ,
grid_size : int = 1
grid_size : int = 1 ,
mid_count : int = None ,
handle_intermediate_image : Callable [ [ Image . Image ] , None ] = None ,
is_verbose : bool = False
) - > Image . Image :
image_count = grid_size * * 2
row_count = 16
image_tokens = self . generate_image_tokens ( text , seed , image_count , row_count )
if torch . cuda . is_available ( ) : torch . cuda . empty_cache ( )
if not self . is_reusable : self . init_detokenizer ( )
if self . is_verbose : print ( " detokenizing image " )
images = self . detokenizer . forward ( image_tokens ) . to ( torch . uint8 )
if not self . is_reusable : del self . detokenizer
images = images . reshape ( [ grid_size ] * 2 + list ( images . shape [ 1 : ] ) )
image = images . flatten ( 1 , 2 ) . transpose ( 0 , 1 ) . flatten ( 1 , 2 )
image = Image . fromarray ( image . to ( ' cpu ' ) . detach ( ) . numpy ( ) )
if torch . cuda . is_available ( ) : torch . cuda . empty_cache ( )
return image
image_tokens = self . generate_image_tokens (
text ,
seed ,
grid_size ,
row_count = 16 ,
mid_count = mid_count ,
handle_intermediate_image = handle_intermediate_image ,
is_verbose = is_verbose
)
return self . image_from_tokens ( grid_size , image_tokens , is_verbose )