|
|
|
@ -16,14 +16,14 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
is_mega: bool, |
|
|
|
|
is_expendable: bool = False, |
|
|
|
|
is_reusable: bool = True, |
|
|
|
|
token_count: int = 256 |
|
|
|
|
): |
|
|
|
|
super().__init__(is_mega) |
|
|
|
|
self.is_expendable = is_expendable |
|
|
|
|
self.is_reusable = is_reusable |
|
|
|
|
self.token_count = token_count |
|
|
|
|
print("initializing MinDalleTorch") |
|
|
|
|
if not is_expendable: |
|
|
|
|
if is_reusable: |
|
|
|
|
self.init_encoder() |
|
|
|
|
self.init_decoder() |
|
|
|
|
self.init_detokenizer() |
|
|
|
@ -84,24 +84,24 @@ class MinDalleTorch(MinDalleBase): |
|
|
|
|
text_tokens = torch.tensor(text_tokens).to(torch.long) |
|
|
|
|
if torch.cuda.is_available(): text_tokens = text_tokens.cuda() |
|
|
|
|
|
|
|
|
|
if self.is_expendable: self.init_encoder() |
|
|
|
|
if not self.is_reusable: self.init_encoder() |
|
|
|
|
print("encoding text tokens") |
|
|
|
|
encoder_state = self.encoder.forward(text_tokens) |
|
|
|
|
if self.is_expendable: del self.encoder |
|
|
|
|
if not self.is_reusable: del self.encoder |
|
|
|
|
|
|
|
|
|
if self.is_expendable: self.init_decoder() |
|
|
|
|
if not self.is_reusable: self.init_decoder() |
|
|
|
|
print("sampling image tokens") |
|
|
|
|
torch.manual_seed(seed) |
|
|
|
|
image_tokens = self.decoder.forward(text_tokens, encoder_state) |
|
|
|
|
if self.is_expendable: del self.decoder |
|
|
|
|
if not self.is_reusable: del self.decoder |
|
|
|
|
return image_tokens |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_image(self, text: str, seed: int) -> Image.Image: |
|
|
|
|
image_tokens = self.generate_image_tokens(text, seed) |
|
|
|
|
if self.is_expendable: self.init_detokenizer() |
|
|
|
|
if not self.is_reusable: self.init_detokenizer() |
|
|
|
|
print("detokenizing image") |
|
|
|
|
image = self.detokenizer.forward(image_tokens).to(torch.uint8) |
|
|
|
|
if self.is_expendable: del self.detokenizer |
|
|
|
|
if not self.is_reusable: del self.detokenizer |
|
|
|
|
image = Image.fromarray(image.to('cpu').detach().numpy()) |
|
|
|
|
return image |