|
|
|
@ -73,7 +73,6 @@ def decode_torch( |
|
|
|
|
|
|
|
|
|
print("sampling image tokens") |
|
|
|
|
torch.manual_seed(seed) |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
image_tokens = decoder.forward(text_tokens, encoder_state) |
|
|
|
|
return image_tokens |
|
|
|
|
|
|
|
|
@ -84,10 +83,9 @@ def generate_image_tokens_torch( |
|
|
|
|
config: dict, |
|
|
|
|
params: dict, |
|
|
|
|
image_token_count: int |
|
|
|
|
) -> numpy.ndarray: |
|
|
|
|
) -> LongTensor: |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
text_tokens = text_tokens.cuda() |
|
|
|
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda() |
|
|
|
|
encoder_state = encode_torch( |
|
|
|
|
text_tokens, |
|
|
|
|
config, |
|
|
|
@ -101,16 +99,15 @@ def generate_image_tokens_torch( |
|
|
|
|
params, |
|
|
|
|
image_token_count |
|
|
|
|
) |
|
|
|
|
return image_tokens.detach().numpy() |
|
|
|
|
return image_tokens |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detokenize_torch(image_tokens: torch.LongTensor) -> numpy.ndarray:
    """Decode image tokens into a uint8 image array with the VQGAN detokenizer.

    Loads the VQGAN parameters from ``./pretrained/vqgan``, builds a
    ``VQGanDetokenizer``, runs it on ``image_tokens`` and returns the result
    as a NumPy array.

    Args:
        image_tokens: Image token ids. A torch tensor is expected; array-likes
            such as ``numpy.ndarray`` are also accepted and converted to a
            ``torch.long`` tensor.

    Returns:
        The detokenizer output cast to ``torch.uint8`` and converted to a
        ``numpy.ndarray``.
    """
    print("detokenizing image")
    model_path = './pretrained/vqgan'
    params = load_vqgan_torch_params(model_path)
    detokenizer = VQGanDetokenizer()
    detokenizer.load_state_dict(params)
    # as_tensor reuses an existing long tensor as-is (no copy, no
    # copy-construct warning) while still converting ndarray/list input.
    image_tokens = torch.as_tensor(image_tokens, dtype=torch.long)
    # NOTE(review): the detokenizer is left on whatever device its params were
    # loaded to; if callers pass CUDA tokens, confirm the module lives on the
    # same device.
    image = detokenizer.forward(image_tokens).to(torch.uint8)
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when the
    # image is already on the CPU.
    return image.detach().cpu().numpy()
|
|
|
|
|