use cuda if available

This commit is contained in:
Brett Kuprel 2022-06-28 12:38:31 -04:00
parent aef24ea157
commit 8544f59576

View File

@ -1,7 +1,7 @@
import numpy import numpy
from typing import Dict from typing import Dict
from torch import LongTensor, FloatTensor
import torch import torch
from torch import Tensor
torch.no_grad() torch.no_grad()
from .models.vqgan_detokenizer import VQGanDetokenizer from .models.vqgan_detokenizer import VQGanDetokenizer
@ -15,10 +15,10 @@ from .load_params import (
def encode_torch( def encode_torch(
text_tokens: numpy.ndarray, text_tokens: LongTensor,
config: dict, config: dict,
params: dict params: dict
) -> Tensor: ) -> FloatTensor:
print("loading torch encoder") print("loading torch encoder")
encoder = DalleBartEncoderTorch( encoder = DalleBartEncoderTorch(
layer_count = config['encoder_layers'], layer_count = config['encoder_layers'],
@ -37,20 +37,19 @@ def encode_torch(
del encoder_params del encoder_params
print("encoding text tokens") print("encoding text tokens")
text_tokens = torch.tensor(text_tokens).to(torch.long)
encoder_state = encoder(text_tokens) encoder_state = encoder(text_tokens)
del encoder del encoder
return encoder_state return encoder_state
def decode_torch( def decode_torch(
text_tokens: Tensor, text_tokens: LongTensor,
encoder_state: Tensor, encoder_state: FloatTensor,
config: dict, config: dict,
seed: int, seed: int,
params: dict, params: dict,
image_token_count: int image_token_count: int
) -> Tensor: ) -> LongTensor:
print("loading torch decoder") print("loading torch decoder")
decoder = DalleBartDecoderTorch( decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'], image_vocab_size = config['image_vocab_size'],
@ -86,6 +85,9 @@ def generate_image_tokens_torch(
params: dict, params: dict,
image_token_count: int image_token_count: int
) -> numpy.ndarray: ) -> numpy.ndarray:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available():
text_tokens = text_tokens.cuda()
encoder_state = encode_torch( encoder_state = encode_torch(
text_tokens, text_tokens,
config, config,