|
|
|
@ -45,8 +45,8 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
is_encoder=True |
|
|
|
|
) |
|
|
|
|
self.encoder.load_state_dict(params, strict=False) |
|
|
|
|
if torch.cuda.is_available(): self.encoder = self.encoder.cuda() |
|
|
|
|
del params |
|
|
|
|
if torch.cuda.is_available(): self.encoder = self.encoder.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_decoder(self): |
|
|
|
@ -69,8 +69,8 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
is_encoder=False |
|
|
|
|
) |
|
|
|
|
self.decoder.load_state_dict(params, strict=False) |
|
|
|
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda() |
|
|
|
|
del params |
|
|
|
|
if torch.cuda.is_available(): self.decoder = self.decoder.cuda() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_detokenizer(self): |
|
|
|
|