@@ -1,5 +1,5 @@
 import torch
-from torch import FloatTensor
+from torch import FloatTensor, LongTensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 torch.set_grad_enabled(False)
@@ -160,12 +160,14 @@ class Decoder(Module):
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()
-        m, n = 2 ** 14, 2 ** 8
-        self.embedding = Embedding(m, n)
-        self.post_quant_conv = Conv2d(n, n, 1)
+        vocab_count, embed_count = 2 ** 14, 2 ** 8
+        self.vocab_count = vocab_count
+        self.embedding = Embedding(vocab_count, embed_count)
+        self.post_quant_conv = Conv2d(embed_count, embed_count, 1)
         self.decoder = Decoder()

-    def forward(self, z: FloatTensor) -> FloatTensor:
+    def forward(self, z: LongTensor) -> FloatTensor:
+        z = z.clamp(0, self.vocab_count - 1)
         z = self.embedding.forward(z)
         z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
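
Note on usage (not part of the diff): after this change the detokenizer takes integer token ids directly, clamping them into vocab range before embedding. Below is a minimal sketch of the new call path, assuming one image's 16x16 token grid flattened to 256 ids; the module path, random ids, and batch shape are illustrative assumptions, not part of the change.

import torch
from min_dalle.models.vqgan_detokenizer import VQGanDetokenizer  # module path is an assumption

detokenizer = VQGanDetokenizer()

# One image's worth of codebook indices: 16 * 16 = 256 ids in [0, 2 ** 14).
token_ids = torch.randint(0, 2 ** 14, (1, 2 ** 8))

# forward() now clamps the ids into vocab range, embeds them to (1, 256, 256),
# reshapes to (1, 16, 16, 256), and permutes to channels-first (1, 256, 16, 16)
# before the decoder runs.
images = detokenizer.forward(token_ids)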