|
|
|
@ -6,13 +6,11 @@ import torch |
|
|
|
|
torch.set_grad_enabled(False) |
|
|
|
|
torch.set_num_threads(os.cpu_count()) |
|
|
|
|
|
|
|
|
|
from .load_params import ( |
|
|
|
|
convert_and_save_torch_params, |
|
|
|
|
load_dalle_bart_flax_params |
|
|
|
|
) |
|
|
|
|
from .load_params import convert_and_save_torch_params |
|
|
|
|
from .min_dalle_base import MinDalleBase |
|
|
|
|
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch |
|
|
|
|
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch |
|
|
|
|
from .models.vqgan_detokenizer import VQGanDetokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MinDalleTorch(MinDalleBase): |
|
|
|
@ -26,15 +24,14 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
super().__init__(is_mega) |
|
|
|
|
self.is_reusable = is_reusable |
|
|
|
|
self.token_count = token_count |
|
|
|
|
|
|
|
|
|
if not is_mega: |
|
|
|
|
self.model_params = load_dalle_bart_flax_params(self.model_path) |
|
|
|
|
|
|
|
|
|
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt') |
|
|
|
|
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt') |
|
|
|
|
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt') |
|
|
|
|
|
|
|
|
|
is_converted = os.path.exists(self.encoder_params_path) |
|
|
|
|
is_converted &= os.path.exists(self.decoder_params_path) |
|
|
|
|
is_converted &= os.path.exists(self.detoker_params_path) |
|
|
|
|
if not is_converted: |
|
|
|
|
convert_and_save_torch_params(is_mega, self.model_path) |
|
|
|
|
|
|
|
|
@ -79,11 +76,14 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
del params |
|
|
|
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_detokenizer(self): |
|
|
|
|
super().init_detokenizer() |
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
self.detokenizer = self.detokenizer.cuda() |
|
|
|
|
print("initializing VQGanDetokenizer") |
|
|
|
|
self.detokenizer = VQGanDetokenizer() |
|
|
|
|
params = torch.load(self.detoker_params_path) |
|
|
|
|
self.detokenizer.load_state_dict(params) |
|
|
|
|
del params |
|
|
|
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image_tokens(self, text: str, seed: int) -> LongTensor: |
|
|
|
|