|
|
|
@ -1,7 +1,7 @@ |
|
|
|
|
import numpy |
|
|
|
|
from typing import Dict |
|
|
|
|
from torch import LongTensor, FloatTensor |
|
|
|
|
import torch |
|
|
|
|
from torch import Tensor |
|
|
|
|
torch.no_grad() |
|
|
|
|
|
|
|
|
|
from .models.vqgan_detokenizer import VQGanDetokenizer |
|
|
|
@ -15,10 +15,10 @@ from .load_params import ( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_torch( |
|
|
|
|
text_tokens: numpy.ndarray, |
|
|
|
|
text_tokens: LongTensor, |
|
|
|
|
config: dict, |
|
|
|
|
params: dict |
|
|
|
|
) -> Tensor: |
|
|
|
|
) -> FloatTensor: |
|
|
|
|
print("loading torch encoder") |
|
|
|
|
encoder = DalleBartEncoderTorch( |
|
|
|
|
layer_count = config['encoder_layers'], |
|
|
|
@ -37,20 +37,19 @@ def encode_torch( |
|
|
|
|
del encoder_params |
|
|
|
|
|
|
|
|
|
print("encoding text tokens") |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
encoder_state = encoder(text_tokens) |
|
|
|
|
del encoder |
|
|
|
|
return encoder_state |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_torch( |
|
|
|
|
text_tokens: Tensor, |
|
|
|
|
encoder_state: Tensor, |
|
|
|
|
text_tokens: LongTensor, |
|
|
|
|
encoder_state: FloatTensor, |
|
|
|
|
config: dict, |
|
|
|
|
seed: int, |
|
|
|
|
params: dict, |
|
|
|
|
image_token_count: int |
|
|
|
|
) -> Tensor: |
|
|
|
|
) -> LongTensor: |
|
|
|
|
print("loading torch decoder") |
|
|
|
|
decoder = DalleBartDecoderTorch( |
|
|
|
|
image_vocab_size = config['image_vocab_size'], |
|
|
|
@ -86,6 +85,9 @@ def generate_image_tokens_torch( |
|
|
|
|
params: dict, |
|
|
|
|
image_token_count: int |
|
|
|
|
) -> numpy.ndarray: |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
text_tokens = text_tokens.cuda() |
|
|
|
|
encoder_state = encode_torch( |
|
|
|
|
text_tokens, |
|
|
|
|
config, |
|
|
|
|