From 9d6b6dcc92e8500c9f18bc131b14a81ecdd914e3 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 28 Jun 2022 12:54:58 -0400 Subject: [PATCH] previous commit broke flax model, fixed now --- min_dalle/generate_image.py | 17 ++++++++--------- min_dalle/load_params.py | 4 ++-- min_dalle/min_dalle_torch.py | 2 +- min_dalle/models/dalle_bart_decoder_torch.py | 10 +++++----- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/min_dalle/generate_image.py b/min_dalle/generate_image.py index 401f9e2..6829248 100644 --- a/min_dalle/generate_image.py +++ b/min_dalle/generate_image.py @@ -3,6 +3,7 @@ import json import numpy from PIL import Image from typing import Tuple, List +import torch from min_dalle.load_params import load_dalle_bart_flax_params from min_dalle.text_tokenizer import TextTokenizer @@ -53,25 +54,23 @@ def generate_image_from_text( text_tokens = tokenize_text(text, config, vocab, merges) params_dalle_bart = load_dalle_bart_flax_params(model_path) - image_tokens = numpy.zeros(config['image_length']) if is_torch: - image_tokens[:image_token_count] = generate_image_tokens_torch( + image_tokens = generate_image_tokens_torch( text_tokens = text_tokens, seed = seed, config = config, params = params_dalle_bart, image_token_count = image_token_count ) + if image_token_count == config['image_length']: + image = detokenize_torch(image_tokens) + return Image.fromarray(image) else: - image_tokens[...] = generate_image_tokens_flax( + image_tokens = generate_image_tokens_flax( text_tokens = text_tokens, seed = seed, config = config, params = params_dalle_bart, ) - - if image_token_count == config['image_length']: - image = detokenize_torch(image_tokens) - return Image.fromarray(image) - else: - return None \ No newline at end of file + image = detokenize_torch(torch.tensor(image_tokens)) + return Image.fromarray(image) \ No newline at end of file diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py index fa66d89..3fd77a8 100644 --- a/min_dalle/load_params.py +++ b/min_dalle/load_params.py @@ -30,7 +30,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]: for i in P: P[i] = torch.tensor(P[i]) - if torch.cuda.is_available(): P[i] = P[i].cuda() + # if torch.cuda.is_available(): P[i] = P[i].cuda() P['embedding.weight'] = P.pop('quantize.embedding.embedding') @@ -87,7 +87,7 @@ def convert_dalle_bart_torch_from_flax_params( for i in P: P[i] = torch.tensor(P[i]) - if torch.cuda.is_available(): P[i] = P[i].cuda() + # if torch.cuda.is_available(): P[i] = P[i].cuda() for i in list(P): if 'kernel' in i: diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 3940815..690e02c 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -85,7 +85,7 @@ def generate_image_tokens_torch( image_token_count: int ) -> LongTensor: text_tokens = torch.tensor(text_tokens).to(torch.long) - if torch.cuda.is_available(): text_tokens = text_tokens.cuda() + # if torch.cuda.is_available(): text_tokens = text_tokens.cuda() encoder_state = encode_torch( text_tokens, config, diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py index 6b4093d..ec14493 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder_torch.py @@ -127,10 +127,10 @@ class DalleBartDecoderTorch(nn.Module): self.start_token = torch.tensor([start_token]).to(torch.long) self.pad_token = torch.tensor([1]).to(torch.long) self.condition_factor = torch.tensor([10]).to(torch.float) - if torch.cuda.is_available(): - self.start_token = self.start_token.cuda() - self.pad_token = self.pad_token.cuda() - self.condition_factor = self.condition_factor.cuda() + # if torch.cuda.is_available(): + # self.start_token = self.start_token.cuda() + # self.pad_token = self.pad_token.cuda() + # self.condition_factor = self.condition_factor.cuda() self.image_token_count = image_token_count self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count) self.embed_positions = nn.Embedding(image_token_count, embed_count) @@ -204,7 +204,7 @@ class DalleBartDecoderTorch(nn.Module): for i in range(self.sample_token_count): token_index = torch.tensor([i]).to(torch.long) - if torch.cuda.is_available(): token_index = token_index.cuda() + # if torch.cuda.is_available(): token_index = token_index.cuda() probs, keys_values_state = self.decode_step( text_tokens = text_tokens, encoder_state = encoder_state,