use cuda if available
This commit is contained in:
parent
aef24ea157
commit
8544f59576
|
@ -1,7 +1,7 @@
|
|||
import numpy
|
||||
from typing import Dict
|
||||
from torch import LongTensor, FloatTensor
|
||||
import torch
|
||||
from torch import Tensor
|
||||
torch.no_grad()
|
||||
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
|
@ -15,10 +15,10 @@ from .load_params import (
|
|||
|
||||
|
||||
def encode_torch(
|
||||
text_tokens: numpy.ndarray,
|
||||
text_tokens: LongTensor,
|
||||
config: dict,
|
||||
params: dict
|
||||
) -> Tensor:
|
||||
) -> FloatTensor:
|
||||
print("loading torch encoder")
|
||||
encoder = DalleBartEncoderTorch(
|
||||
layer_count = config['encoder_layers'],
|
||||
|
@ -37,20 +37,19 @@ def encode_torch(
|
|||
del encoder_params
|
||||
|
||||
print("encoding text tokens")
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
encoder_state = encoder(text_tokens)
|
||||
del encoder
|
||||
return encoder_state
|
||||
|
||||
|
||||
def decode_torch(
|
||||
text_tokens: Tensor,
|
||||
encoder_state: Tensor,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
config: dict,
|
||||
seed: int,
|
||||
params: dict,
|
||||
image_token_count: int
|
||||
) -> Tensor:
|
||||
) -> LongTensor:
|
||||
print("loading torch decoder")
|
||||
decoder = DalleBartDecoderTorch(
|
||||
image_vocab_size = config['image_vocab_size'],
|
||||
|
@ -86,6 +85,9 @@ def generate_image_tokens_torch(
|
|||
params: dict,
|
||||
image_token_count: int
|
||||
) -> numpy.ndarray:
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
text_tokens = text_tokens.cuda()
|
||||
encoder_state = encode_torch(
|
||||
text_tokens,
|
||||
config,
|
||||
|
|
Loading…
Reference in New Issue
Block a user