Fixed bug with CUDA in detokenizer
This commit is contained in:
parent
764b0bc685
commit
1fbb209623
|
@ -6,7 +6,7 @@ This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
Run `sh setup.sh` to install dependencies and download pretrained models. In the bash script, Git LFS is used to download the VQGan detokenizer from Hugging Face and the Weights & Biases Python package is used to download the DALL·E Mini and DALL·E Mega transformer models. These models can also be downloaded manually:
|
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
|
||||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||||
|
|
|
@ -63,7 +63,7 @@ def generate_image_from_text(
|
||||||
image_token_count = image_token_count
|
image_token_count = image_token_count
|
||||||
)
|
)
|
||||||
if image_token_count == config['image_length']:
|
if image_token_count == config['image_length']:
|
||||||
image = detokenize_torch(image_tokens)
|
image = detokenize_torch(image_tokens, is_torch=True)
|
||||||
return Image.fromarray(image)
|
return Image.fromarray(image)
|
||||||
else:
|
else:
|
||||||
print(list(image_tokens.to('cpu').detach().numpy()))
|
print(list(image_tokens.to('cpu').detach().numpy()))
|
||||||
|
@ -74,5 +74,5 @@ def generate_image_from_text(
|
||||||
config = config,
|
config = config,
|
||||||
params = params_dalle_bart,
|
params = params_dalle_bart,
|
||||||
)
|
)
|
||||||
image = detokenize_torch(torch.tensor(image_tokens))
|
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
|
||||||
return Image.fromarray(image)
|
return Image.fromarray(image)
|
|
@ -104,13 +104,13 @@ def generate_image_tokens_torch(
|
||||||
return image_tokens
|
return image_tokens
|
||||||
|
|
||||||
|
|
||||||
def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
|
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
|
||||||
print("detokenizing image")
|
print("detokenizing image")
|
||||||
model_path = './pretrained/vqgan'
|
model_path = './pretrained/vqgan'
|
||||||
params = load_vqgan_torch_params(model_path)
|
params = load_vqgan_torch_params(model_path)
|
||||||
detokenizer = VQGanDetokenizer()
|
detokenizer = VQGanDetokenizer()
|
||||||
detokenizer.load_state_dict(params)
|
detokenizer.load_state_dict(params)
|
||||||
# if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
|
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
|
||||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
del detokenizer, params
|
del detokenizer, params
|
||||||
return image.to('cpu').detach().numpy()
|
return image.to('cpu').detach().numpy()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user