diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 13cc134..228c601 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -1,8 +1,10 @@ import numpy +import os from typing import Dict from torch import LongTensor, FloatTensor import torch torch.set_grad_enabled(False) +torch.set_num_threads(os.cpu_count()) from .models.vqgan_detokenizer import VQGanDetokenizer from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch @@ -114,4 +116,3 @@ def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray: image = detokenizer.forward(image_tokens).to(torch.uint8) del detokenizer, params return image.to('cpu').detach().numpy() - \ No newline at end of file