fixed bug with cuda in detokenizer
@@ -63,7 +63,7 @@ def generate_image_from_text(
         image_token_count = image_token_count
     )
     if image_token_count == config['image_length']:
-        image = detokenize_torch(image_tokens)
+        image = detokenize_torch(image_tokens, is_torch=True)
         return Image.fromarray(image)
     else:
         print(list(image_tokens.to('cpu').detach().numpy()))
@@ -74,5 +74,5 @@ def generate_image_from_text(
         config = config,
         params = params_dalle_bart,
     )
-    image = detokenize_torch(torch.tensor(image_tokens))
+    image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
     return Image.fromarray(image)
@@ -104,13 +104,13 @@ def generate_image_tokens_torch(
     return image_tokens


-def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
+def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
     print("detokenizing image")
     model_path = './pretrained/vqgan'
     params = load_vqgan_torch_params(model_path)
     detokenizer = VQGanDetokenizer()
     detokenizer.load_state_dict(params)
-    # if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
+    if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
     image = detokenizer.forward(image_tokens).to(torch.uint8)
     del detokenizer, params
     return image.to('cpu').detach().numpy()
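For context, a minimal sketch of the bug this commit fixes (illustrative only, not code from the repository): on the non-torch path the tokens arrive via torch.tensor(...) and live on the CPU, so unconditionally moving the detokenizer to CUDA (as the previously commented-out line would have done) causes a device mismatch at forward time. The new is_torch flag only moves the detokenizer to the GPU when the tokens came from the torch generation path. The stand-in module below is hypothetical; it just demonstrates the mismatch and the gated .cuda() call.

import torch

# Stand-in for the VQGAN detokenizer (hypothetical, much smaller).
detokenizer = torch.nn.Embedding(16384, 4)
# Tokens produced on the CPU, as in the is_torch=False path.
cpu_tokens = torch.tensor([0, 1, 2, 3])

if torch.cuda.is_available():
    detokenizer = detokenizer.cuda()      # unconditional move, the old behaviour
    try:
        detokenizer(cpu_tokens)           # fails: weights on cuda:0, indices on cpu
    except RuntimeError as err:
        print('device mismatch:', err)

    # The fix gates the move: only use CUDA when the tokens are already
    # CUDA tensors (is_torch=True); otherwise stay on the CPU.
    detokenizer = detokenizer.cpu()
    print(detokenizer(cpu_tokens).shape)  # works on the CPU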