clamp tokens to appropriate bounds

This commit is contained in:
Brett Kuprel 2022-07-08 09:19:28 -04:00
parent e409c120d0
commit 9eb5633931
5 changed files with 13 additions and 7 deletions

2
cog.yaml vendored
View File

@ -6,7 +6,7 @@ build:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.3.3" - "min-dalle==0.3.4"
run: run:
- pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -117,6 +117,7 @@ class DalleBartDecoder(nn.Module):
super().__init__() super().__init__()
self.layer_count = layer_count self.layer_count = layer_count
self.embed_count = embed_count self.embed_count = embed_count
self.image_vocab_count = image_vocab_count
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count) self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
self.layers: List[DecoderLayer] = nn.ModuleList([ self.layers: List[DecoderLayer] = nn.ModuleList([
@ -152,6 +153,7 @@ class DalleBartDecoder(nn.Module):
image_count = encoder_state.shape[0] // 2 image_count = encoder_state.shape[0] // 2
token_index_batched = token_index[[0] * image_count * 2] token_index_batched = token_index[[0] * image_count * 2]
prev_tokens = prev_tokens[list(range(image_count)) * 2] prev_tokens = prev_tokens[list(range(image_count)) * 2]
prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
decoder_state = self.embed_tokens.forward(prev_tokens) decoder_state = self.embed_tokens.forward(prev_tokens)
decoder_state += self.embed_positions.forward(token_index_batched) decoder_state += self.embed_positions.forward(token_index_batched)
decoder_state = self.layernorm_embedding.forward(decoder_state) decoder_state = self.layernorm_embedding.forward(decoder_state)

View File

@ -119,6 +119,7 @@ class DalleBartEncoder(nn.Module):
glu_embed_count: int glu_embed_count: int
): ):
super().__init__() super().__init__()
self.text_vocab_count = text_vocab_count
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count) self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: List[EncoderLayer] = nn.ModuleList([ self.layers: List[EncoderLayer] = nn.ModuleList([
@ -138,6 +139,7 @@ class DalleBartEncoder(nn.Module):
def forward(self, text_tokens: LongTensor) -> FloatTensor: def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1) attention_mask = text_tokens.not_equal(1)
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]] pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
encoder_state = ( encoder_state = (
self.embed_tokens.forward(text_tokens) + self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens) self.embed_positions.forward(pose_tokens)

View File

@ -1,5 +1,5 @@
import torch import torch
from torch import FloatTensor from torch import FloatTensor, LongTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
@ -160,12 +160,14 @@ class Decoder(Module):
class VQGanDetokenizer(Module): class VQGanDetokenizer(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
m, n = 2 ** 14, 2 ** 8 vocab_count, embed_count = 2 ** 14, 2 ** 8
self.embedding = Embedding(m, n) self.vocab_count = vocab_count
self.post_quant_conv = Conv2d(n, n, 1) self.embedding = Embedding(vocab_count, embed_count)
self.post_quant_conv = Conv2d(embed_count, embed_count, 1)
self.decoder = Decoder() self.decoder = Decoder()
def forward(self, z: FloatTensor) -> FloatTensor: def forward(self, z: LongTensor) -> FloatTensor:
z = z.clamp(0, self.vocab_count - 1)
z = self.embedding.forward(z) z = self.embedding.forward(z)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous() z = z.permute(0, 3, 1, 2).contiguous()

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(), # long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.3.3', version='0.3.4',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',