From 1fbb209623a512bd954fc336443f3ae0dd0de2d0 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 28 Jun 2022 22:02:35 -0400 Subject: [PATCH] fixed bug with cuda in detokenizer --- README.md | 2 +- min_dalle/generate_image.py | 4 ++-- min_dalle/min_dalle_torch.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 81f467d..a268946 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma ### Setup -Run `sh setup.sh` to install dependencies and download pretrained models. In the bash script, Git LFS is used to download the VQGan detokenizer from Hugging Face and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. These models can also be downloaded manually: +Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually: [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) diff --git a/min_dalle/generate_image.py b/min_dalle/generate_image.py index ba7fd55..f7f63cb 100644 --- a/min_dalle/generate_image.py +++ b/min_dalle/generate_image.py @@ -63,7 +63,7 @@ def generate_image_from_text( image_token_count = image_token_count ) if image_token_count == config['image_length']: - image = detokenize_torch(image_tokens) + image = detokenize_torch(image_tokens, is_torch=True) return Image.fromarray(image) else: print(list(image_tokens.to('cpu').detach().numpy())) @@ -74,5 +74,5 @@ def generate_image_from_text( config = config, params = params_dalle_bart, ) - image = detokenize_torch(torch.tensor(image_tokens)) + image = detokenize_torch(torch.tensor(image_tokens), is_torch=False) return Image.fromarray(image) \ No newline at end of file diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 215181b..13cc134 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -104,13 +104,13 @@ def generate_image_tokens_torch( return image_tokens -def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray: +def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray: print("detokenizing image") model_path = './pretrained/vqgan' params = load_vqgan_torch_params(model_path) detokenizer = VQGanDetokenizer() detokenizer.load_state_dict(params) - # if torch.cuda.is_available(): detokenizer = detokenizer.cuda() + if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda() image = detokenizer.forward(image_tokens).to(torch.uint8) del detokenizer, params return image.to('cpu').detach().numpy()