use cuda if available

This commit is contained in:
Brett Kuprel 2022-06-28 12:38:31 -04:00
parent aef24ea157
commit 8544f59576

View File

@@ -1,7 +1,7 @@
import numpy
from typing import Dict
from torch import LongTensor, FloatTensor
import torch
from torch import Tensor
torch.no_grad()
from .models.vqgan_detokenizer import VQGanDetokenizer
@@ -15,10 +15,10 @@ from .load_params import (
def encode_torch(
text_tokens: numpy.ndarray,
text_tokens: LongTensor,
config: dict,
params: dict
) -> Tensor:
) -> FloatTensor:
print("loading torch encoder")
encoder = DalleBartEncoderTorch(
layer_count = config['encoder_layers'],
@@ -37,20 +37,19 @@ def encode_torch(
del encoder_params
print("encoding text tokens")
text_tokens = torch.tensor(text_tokens).to(torch.long)
encoder_state = encoder(text_tokens)
del encoder
return encoder_state
def decode_torch(
text_tokens: Tensor,
encoder_state: Tensor,
text_tokens: LongTensor,
encoder_state: FloatTensor,
config: dict,
seed: int,
params: dict,
image_token_count: int
) -> Tensor:
) -> LongTensor:
print("loading torch decoder")
decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'],
@@ -86,6 +85,9 @@ def generate_image_tokens_torch(
params: dict,
image_token_count: int
) -> numpy.ndarray:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available():
text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,