diff --git a/image_from_text.py b/image_from_text.py index a0a3bdf..11878a2 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -15,7 +15,7 @@ parser.set_defaults(torch=False) parser.add_argument('--text', type=str, default='alien life') parser.add_argument('--seed', type=int, default=7) parser.add_argument('--image_path', type=str, default='generated') -parser.add_argument('--sample_token_count', type=int, default=256) # for debugging +parser.add_argument('--token_count', type=int, default=256) # for debugging def ascii_from_image(image: Image.Image, size: int) -> str: @@ -42,20 +42,21 @@ def generate_image( text: str, seed: int, image_path: str, - sample_token_count: int + token_count: int ): + is_expendable = True if is_torch: - image_generator = MinDalleTorch(is_mega, sample_token_count) - image_tokens = image_generator.generate_image_tokens(text, seed) + image_generator = MinDalleTorch(is_mega, is_expendable, token_count) - if sample_token_count < image_generator.config['image_length']: + if token_count < image_generator.config['image_length']: + image_tokens = image_generator.generate_image_tokens(text, seed) print('image tokens', list(image_tokens.to('cpu').detach().numpy())) return else: image = image_generator.generate_image(text, seed) else: - image_generator = MinDalleFlax(is_mega) + image_generator = MinDalleFlax(is_mega, is_expendable=True) image = image_generator.generate_image(text, seed) save_image(image, image_path) @@ -71,5 +72,5 @@ if __name__ == '__main__': text=args.text, seed=args.seed, image_path=args.image_path, - sample_token_count=args.sample_token_count + token_count=args.token_count ) \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle_base.py similarity index 85% rename from min_dalle/min_dalle.py rename to min_dalle/min_dalle_base.py index adaa6f7..aa7cd22 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle_base.py @@ -6,7 +6,7 @@ from .text_tokenizer import TextTokenizer from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params from .models.vqgan_detokenizer import VQGanDetokenizer -class MinDalle: +class MinDalleBase: def __init__(self, is_mega: bool): self.is_mega = is_mega model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') @@ -25,11 +25,15 @@ class MinDalle: merges = f.read().split("\n")[1:-1] self.model_params = load_dalle_bart_flax_params(model_path) - self.tokenizer = TextTokenizer(vocab, merges) + + + def init_detokenizer(self): + print("initializing VQGanDetokenizer") + params = load_vqgan_torch_params('./pretrained/vqgan') self.detokenizer = VQGanDetokenizer() - vqgan_params = load_vqgan_torch_params('./pretrained/vqgan') - self.detokenizer.load_state_dict(vqgan_params) + self.detokenizer.load_state_dict(params) + del params def tokenize_text(self, text: str) -> numpy.ndarray: diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py index 100d4ab..b60b538 100644 --- a/min_dalle/min_dalle_flax.py +++ b/min_dalle/min_dalle_flax.py @@ -3,18 +3,25 @@ import numpy from PIL import Image import torch -from .min_dalle import MinDalle +from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax -class MinDalleFlax(MinDalle): - def __init__(self, is_mega: bool): +class MinDalleFlax(MinDalleBase): + def __init__(self, is_mega: bool, is_expendable: bool = False): super().__init__(is_mega) + self.is_expendable = is_expendable print("initializing MinDalleFlax") + if not is_expendable: + self.init_encoder() + self.init_decoder() + self.init_detokenizer() - print("loading encoder") - self.encoder = DalleBartEncoderFlax( + + def init_encoder(self): + print("initializing DalleBartEncoderFlax") + self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( attention_head_count = self.config['encoder_attention_heads'], embed_count = self.config['d_model'], glu_embed_count = self.config['encoder_ffn_dim'], @@ -23,7 +30,9 @@ class MinDalleFlax(MinDalle): layer_count = self.config['encoder_layers'] ).bind({'params': self.model_params.pop('encoder')}) - print("loading decoder") + + def init_decoder(self): + print("initializing DalleBartDecoderFlax") self.decoder = DalleBartDecoderFlax( image_token_count = self.config['image_length'], text_token_count = self.config['max_text_length'], @@ -39,20 +48,30 @@ class MinDalleFlax(MinDalle): def generate_image(self, text: str, seed: int) -> Image.Image: text_tokens = self.tokenize_text(text) + if self.is_expendable: self.init_encoder() print("encoding text tokens") encoder_state = self.encoder(text_tokens) + if self.is_expendable: del self.encoder + if self.is_expendable: + self.init_decoder() + params = self.model_params.pop('decoder') + else: + params = self.model_params['decoder'] print("sampling image tokens") image_tokens = self.decoder.sample_image_tokens( text_tokens, encoder_state, jax.random.PRNGKey(seed), - self.model_params['decoder'] + params ) + if self.is_expendable: del self.decoder image_tokens = torch.tensor(numpy.array(image_tokens)) + if self.is_expendable: self.init_detokenizer() print("detokenizing image") image = self.detokenizer.forward(image_tokens).to(torch.uint8) + if self.is_expendable: del self.detokenizer image = Image.fromarray(image.to('cpu').detach().numpy()) return image \ No newline at end of file diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 6bf71af..2dd21cd 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -9,17 +9,30 @@ torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) from .load_params import convert_dalle_bart_torch_from_flax_params -from .min_dalle import MinDalle +from .min_dalle_base import MinDalleBase from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch -class MinDalleTorch(MinDalle): - def __init__(self, is_mega: bool, sample_token_count: int = 256): +class MinDalleTorch(MinDalleBase): + def __init__( + self, + is_mega: bool, + is_expendable: bool = False, + token_count: int = 256 + ): super().__init__(is_mega) + self.is_expendable = is_expendable + self.token_count = token_count print("initializing MinDalleTorch") + if not is_expendable: + self.init_encoder() + self.init_decoder() + self.init_detokenizer() - print("loading encoder") + + def init_encoder(self): + print("initializing DalleBartEncoderTorch") self.encoder = DalleBartEncoderTorch( layer_count = self.config['encoder_layers'], embed_count = self.config['d_model'], @@ -28,18 +41,22 @@ class MinDalleTorch(MinDalle): text_token_count = self.config['max_text_length'], glu_embed_count = self.config['encoder_ffn_dim'] ) - encoder_params = convert_dalle_bart_torch_from_flax_params( + params = convert_dalle_bart_torch_from_flax_params( self.model_params.pop('encoder'), layer_count=self.config['encoder_layers'], is_encoder=True ) - self.encoder.load_state_dict(encoder_params, strict=False) + self.encoder.load_state_dict(params, strict=False) + if torch.cuda.is_available(): self.encoder = self.encoder.cuda() + del params + - print("loading decoder") + def init_decoder(self): + print("initializing DalleBartDecoderTorch") self.decoder = DalleBartDecoderTorch( image_vocab_size = self.config['image_vocab_size'], image_token_count = self.config['image_length'], - sample_token_count = sample_token_count, + sample_token_count = self.token_count, embed_count = self.config['d_model'], attention_head_count = self.config['decoder_attention_heads'], glu_embed_count = self.config['decoder_ffn_dim'], @@ -48,36 +65,45 @@ class MinDalleTorch(MinDalle): start_token = self.config['decoder_start_token_id'], is_verbose = True ) - decoder_params = convert_dalle_bart_torch_from_flax_params( + params = convert_dalle_bart_torch_from_flax_params( self.model_params.pop('decoder'), layer_count=self.config['decoder_layers'], is_encoder=False ) - self.decoder.load_state_dict(decoder_params, strict=False) + self.decoder.load_state_dict(params, strict=False) + if torch.cuda.is_available(): self.decoder = self.decoder.cuda() + del params + + def init_detokenizer(self): + super().init_detokenizer() if torch.cuda.is_available(): - self.encoder = self.encoder.cuda() - self.decoder = self.decoder.cuda() self.detokenizer = self.detokenizer.cuda() - + def generate_image_tokens(self, text: str, seed: int) -> LongTensor: text_tokens = self.tokenize_text(text) text_tokens = torch.tensor(text_tokens).to(torch.long) if torch.cuda.is_available(): text_tokens = text_tokens.cuda() + if self.is_expendable: self.init_encoder() print("encoding text tokens") encoder_state = self.encoder.forward(text_tokens) + if self.is_expendable: del self.encoder + if self.is_expendable: self.init_decoder() print("sampling image tokens") torch.manual_seed(seed) image_tokens = self.decoder.forward(text_tokens, encoder_state) + if self.is_expendable: del self.decoder return image_tokens def generate_image(self, text: str, seed: int) -> Image.Image: image_tokens = self.generate_image_tokens(text, seed) + if self.is_expendable: self.init_detokenizer() print("detokenizing image") image = self.detokenizer.forward(image_tokens).to(torch.uint8) + if self.is_expendable: del self.detokenizer image = Image.fromarray(image.to('cpu').detach().numpy()) return image \ No newline at end of file