use cuda if available
parent aef24ea157
commit 8544f59576
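The change itself is the standard "use CUDA if it's there" pattern: generate_image_tokens_torch now builds the text-token tensor once and moves it to the GPU only when torch reports one, while encode_torch and decode_torch take already-built LongTensor / FloatTensor arguments instead of converting internally. A minimal sketch of that pattern, independent of this repository's code (the token ids below are made up for illustration):

import torch

# Hypothetical token ids; in the real code they arrive as a numpy array.
text_tokens = torch.tensor([[0, 23, 45, 2]]).to(torch.long)

# Move to the GPU only when one is actually available; otherwise stay on the CPU.
if torch.cuda.is_available():
    text_tokens = text_tokens.cuda()

print(text_tokens.device)  # cuda:0 on a GPU machine, cpu otherwise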
@@ -1,7 +1,7 @@
 import numpy
 from typing import Dict
+from torch import LongTensor, FloatTensor
 import torch
-from torch import Tensor
 torch.no_grad()
 
 from .models.vqgan_detokenizer import VQGanDetokenizer
@@ -15,10 +15,10 @@ from .load_params import (
 
 
 def encode_torch(
-    text_tokens: numpy.ndarray,
+    text_tokens: LongTensor,
     config: dict,
     params: dict
-) -> Tensor:
+) -> FloatTensor:
     print("loading torch encoder")
     encoder = DalleBartEncoderTorch(
         layer_count = config['encoder_layers'],
@@ -37,20 +37,19 @@ def encode_torch(
     del encoder_params
 
     print("encoding text tokens")
-    text_tokens = torch.tensor(text_tokens).to(torch.long)
     encoder_state = encoder(text_tokens)
     del encoder
     return encoder_state
 
 
 def decode_torch(
-    text_tokens: Tensor,
-    encoder_state: Tensor,
+    text_tokens: LongTensor,
+    encoder_state: FloatTensor,
     config: dict,
     seed: int,
     params: dict,
     image_token_count: int
-) -> Tensor:
+) -> LongTensor:
     print("loading torch decoder")
     decoder = DalleBartDecoderTorch(
         image_vocab_size = config['image_vocab_size'],
@@ -86,6 +85,9 @@ def generate_image_tokens_torch(
     params: dict,
     image_token_count: int
 ) -> numpy.ndarray:
+    text_tokens = torch.tensor(text_tokens).to(torch.long)
+    if torch.cuda.is_available():
+        text_tokens = text_tokens.cuda()
     encoder_state = encode_torch(
         text_tokens,
         config,
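Moving the token tensor onto CUDA only pays off if the encoder and decoder parameters end up on the same device; this diff covers the input side only, and in plain PyTorch a CUDA input fed to CPU-resident weights raises a device-mismatch RuntimeError. A small sketch of that constraint with a stand-in module (arbitrary sizes, not this repository's encoder):

import torch
from torch import nn

# Stand-in for the encoder: any nn.Module with parameters will do.
model = nn.Embedding(num_embeddings=1024, embedding_dim=8)

tokens = torch.tensor([[0, 23, 45, 2]]).to(torch.long)
if torch.cuda.is_available():
    # Input and module must both be moved, or the forward pass fails
    # with a device-mismatch RuntimeError.
    tokens = tokens.cuda()
    model = model.cuda()

output = model(tokens)  # runs on the GPU when available, on the CPU otherwise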