From 8544f59576a80032df0368302ee35596d91b5459 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 28 Jun 2022 12:38:31 -0400 Subject: [PATCH] use cuda if available --- min_dalle/min_dalle_torch.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 16ac318..5846e7b 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -1,7 +1,7 @@ import numpy from typing import Dict +from torch import LongTensor, FloatTensor import torch -from torch import Tensor torch.no_grad() from .models.vqgan_detokenizer import VQGanDetokenizer @@ -15,10 +15,10 @@ from .load_params import ( def encode_torch( - text_tokens: numpy.ndarray, + text_tokens: LongTensor, config: dict, params: dict -) -> Tensor: +) -> FloatTensor: print("loading torch encoder") encoder = DalleBartEncoderTorch( layer_count = config['encoder_layers'], @@ -37,20 +37,19 @@ def encode_torch( del encoder_params print("encoding text tokens") - text_tokens = torch.tensor(text_tokens).to(torch.long) encoder_state = encoder(text_tokens) del encoder return encoder_state def decode_torch( - text_tokens: Tensor, - encoder_state: Tensor, + text_tokens: LongTensor, + encoder_state: FloatTensor, config: dict, seed: int, params: dict, image_token_count: int -) -> Tensor: +) -> LongTensor: print("loading torch decoder") decoder = DalleBartDecoderTorch( image_vocab_size = config['image_vocab_size'], @@ -86,6 +85,9 @@ def generate_image_tokens_torch( params: dict, image_token_count: int ) -> numpy.ndarray: + text_tokens = torch.tensor(text_tokens).to(torch.long) + if torch.cuda.is_available(): + text_tokens = text_tokens.cuda() encoder_state = encode_torch( text_tokens, config,