diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py
index ef081d8..fa66d89 100644
--- a/min_dalle/load_params.py
+++ b/min_dalle/load_params.py
@@ -2,8 +2,11 @@ import os
 import numpy
 from copy import deepcopy
 from typing import Dict
-import torch
 from flax import traverse_util, serialization
+import torch
+# Inference only: switch autograd off process-wide. A bare `torch.no_grad()`
+# merely builds a context manager and discards it, leaving grad mode unchanged.
+torch.set_grad_enabled(False)
 
 
 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
@@ -29,6 +32,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
 
     for i in P:
         P[i] = torch.tensor(P[i])
+        if torch.cuda.is_available(): P[i] = P[i].cuda()
 
     P['embedding.weight'] = P.pop('quantize.embedding.embedding')
 
@@ -85,6 +89,7 @@ def convert_dalle_bart_torch_from_flax_params(
 
     for i in P:
         P[i] = torch.tensor(P[i])
+        if torch.cuda.is_available(): P[i] = P[i].cuda()
 
     for i in list(P):
         if 'kernel' in i:
diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py
index 5846e7b..3940815 100644
--- a/min_dalle/min_dalle_torch.py
+++ b/min_dalle/min_dalle_torch.py
@@ -73,7 +73,6 @@ def decode_torch(
 
     print("sampling image tokens")
     torch.manual_seed(seed)
-    text_tokens = torch.tensor(text_tokens).to(torch.long)
     image_tokens = decoder.forward(text_tokens, encoder_state)
     return image_tokens
 
@@ -84,10 +83,9 @@ def generate_image_tokens_torch(
     config: dict,
     params: dict,
     image_token_count: int
-) -> numpy.ndarray:
+) -> LongTensor:
     text_tokens = torch.tensor(text_tokens).to(torch.long)
-    if torch.cuda.is_available():
-        text_tokens = text_tokens.cuda()
+    if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
     encoder_state = encode_torch(
         text_tokens,
         config,
@@ -101,16 +99,16 @@ def generate_image_tokens_torch(
         params,
         image_token_count
     )
-    return image_tokens.detach().numpy()
+    return image_tokens
 
 
-def detokenize_torch(image_tokens: numpy.ndarray) -> numpy.ndarray:
+def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
     print("detokenizing image")
     model_path = './pretrained/vqgan'
     params = load_vqgan_torch_params(model_path)
 
     detokenizer = VQGanDetokenizer()
     detokenizer.load_state_dict(params)
-    image_tokens = torch.tensor(image_tokens).to(torch.long)
     image = detokenizer.forward(image_tokens).to(torch.uint8)
-    return image.detach().numpy()
\ No newline at end of file
+    # Tensor.numpy() only accepts CPU tensors; .cpu() is a no-op when already there.
+    return image.detach().cpu().numpy()
\ No newline at end of file
diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py
index 5814041..6b4093d 100644
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder_torch.py
@@ -127,6 +127,10 @@ class DalleBartDecoderTorch(nn.Module):
         self.start_token = torch.tensor([start_token]).to(torch.long)
         self.pad_token = torch.tensor([1]).to(torch.long)
         self.condition_factor = torch.tensor([10]).to(torch.float)
+        if torch.cuda.is_available():
+            self.start_token = self.start_token.cuda()
+            self.pad_token = self.pad_token.cuda()
+            self.condition_factor = self.condition_factor.cuda()
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
@@ -200,6 +204,7 @@ class DalleBartDecoderTorch(nn.Module):
 
         for i in range(self.sample_token_count):
             token_index = torch.tensor([i]).to(torch.long)
+            if torch.cuda.is_available(): token_index = token_index.cuda()
             probs, keys_values_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,