From 9eb5633931359aa8edd42ed0311f47943f8b1242 Mon Sep 17 00:00:00 2001
From: Brett Kuprel <brkuprel@gmail.com>
Date: Fri, 8 Jul 2022 09:19:28 -0400
Subject: [PATCH] clamp tokens to appropriate bounds

---
 cog.yaml                               |  2 +-
 min_dalle/models/dalle_bart_decoder.py |  2 ++
 min_dalle/models/dalle_bart_encoder.py |  2 ++
 min_dalle/models/vqgan_detokenizer.py  | 12 +++++++-----
 setup.py                               |  2 +-
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/cog.yaml b/cog.yaml
index 4aa3f59..0541f8d 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -6,7 +6,7 @@ build:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "min-dalle==0.3.3"
+    - "min-dalle==0.3.4"
   run:
     - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index 5d34bac..917f354 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -117,6 +117,7 @@ class DalleBartDecoder(nn.Module):
         super().__init__()
         self.layer_count = layer_count
         self.embed_count = embed_count
+        self.image_vocab_count = image_vocab_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
         self.layers: List[DecoderLayer] = nn.ModuleList([
@@ -152,6 +153,7 @@ class DalleBartDecoder(nn.Module):
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
+        prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
         decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py
index a96cd6b..7081fa3 100644
--- a/min_dalle/models/dalle_bart_encoder.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -119,6 +119,7 @@ class DalleBartEncoder(nn.Module):
         glu_embed_count: int
     ):
         super().__init__()
+        self.text_vocab_count = text_vocab_count
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
         self.layers: List[EncoderLayer] = nn.ModuleList([
@@ -138,6 +139,7 @@ class DalleBartEncoder(nn.Module):
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
         pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
+        text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py
index 19881b5..c3ffebb 100644
--- a/min_dalle/models/vqgan_detokenizer.py
+++ b/min_dalle/models/vqgan_detokenizer.py
@@ -1,5 +1,5 @@
 import torch
-from torch import FloatTensor
+from torch import FloatTensor, LongTensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 torch.set_grad_enabled(False)
 
@@ -160,12 +160,14 @@ class Decoder(Module):
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()
-        m, n = 2 ** 14, 2 ** 8
-        self.embedding = Embedding(m, n)
-        self.post_quant_conv = Conv2d(n, n, 1)
+        vocab_count, embed_count = 2 ** 14, 2 ** 8
+        self.vocab_count = vocab_count
+        self.embedding = Embedding(vocab_count, embed_count)
+        self.post_quant_conv = Conv2d(embed_count, embed_count, 1)
         self.decoder = Decoder()
 
-    def forward(self, z: FloatTensor) -> FloatTensor:
+    def forward(self, z: LongTensor) -> FloatTensor:
+        z = z.clamp(0, self.vocab_count - 1)
         z = self.embedding.forward(z)
         z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
diff --git a/setup.py b/setup.py
index ebf9cd9..9675cb7 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     # long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.3.3',
+    version='0.3.4',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
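
Note on why the clamps matter: torch.nn.Embedding is a plain table lookup, so any
index outside [0, num_embeddings - 1] raises an IndexError on CPU or trips a
device-side assert on CUDA. Clamping incoming tokens guarantees every lookup is
valid. A minimal standalone sketch of the pattern this patch applies (illustrative
names, not the actual min-dalle modules):

    import torch
    from torch import nn

    vocab_count, embed_count = 2 ** 14, 2 ** 8
    embedding = nn.Embedding(vocab_count, embed_count)

    # An out-of-range token such as vocab_count + 7 would make the lookup
    # fail; clamping maps it to the nearest valid index instead.
    tokens = torch.tensor([0, 123, vocab_count + 7])
    safe_tokens = tokens.clamp(0, vocab_count - 1)
    embedded = embedding(safe_tokens)  # shape: (3, 256)

The bounds differ deliberately: the decoder clamps to image_vocab_count inclusive
because its table holds image_vocab_count + 1 rows (the extra row presumably being
the start-of-image token), while the encoder and detokenizer clamp to their
vocabulary size minus one.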