clamp tokens to appropriate bounds

parent e409c120d0
commit 9eb5633931
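Every change below follows the same pattern: token ids are clamped into the valid index range of the embedding table they are about to be looked up in, so a stray out-of-range id cannot raise an index error inside nn.Embedding. A minimal sketch of that pattern, with made-up sizes rather than the model's real ones:

import torch
from torch import nn

vocab_count = 16                             # illustrative size only
embed_tokens = nn.Embedding(vocab_count, 4)  # 4-dim embeddings, also illustrative

tokens = torch.tensor([3, 7, 99])            # 99 is outside [0, vocab_count - 1]
tokens = tokens.clamp(0, vocab_count - 1)    # 99 -> 15, now a valid row index
vectors = embed_tokens.forward(tokens)       # shape (3, 4), no IndexError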
cog.yaml (vendored)

@@ -6,7 +6,7 @@ build:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "min-dalle==0.3.3"
+    - "min-dalle==0.3.4"
   run:
     - pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
@@ -117,6 +117,7 @@ class DalleBartDecoder(nn.Module):
         super().__init__()
         self.layer_count = layer_count
         self.embed_count = embed_count
+        self.image_vocab_count = image_vocab_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
         self.layers: List[DecoderLayer] = nn.ModuleList([
@@ -152,6 +153,7 @@ class DalleBartDecoder(nn.Module):
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
+        prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
         decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
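Note that this clamp uses an inclusive upper bound of self.image_vocab_count rather than image_vocab_count - 1: as the __init__ hunk above shows, the decoder's token embedding is built with image_vocab_count + 1 rows, so index image_vocab_count is still valid. (The encoder and detokenizer below, whose tables have exactly vocab_count rows, clamp to vocab_count - 1 instead.) A small sketch of the difference, again with made-up sizes:

import torch
from torch import nn

image_vocab_count = 16                                 # illustrative only
embed_tokens = nn.Embedding(image_vocab_count + 1, 4)  # one extra row, as in the decoder

prev_tokens = torch.tensor([0, 5, 999])
prev_tokens = prev_tokens.clamp(0, image_vocab_count)  # 999 -> 16, the last valid row
decoder_state = embed_tokens.forward(prev_tokens)      # shape (3, 4), no IndexError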
@@ -119,6 +119,7 @@ class DalleBartEncoder(nn.Module):
         glu_embed_count: int
     ):
         super().__init__()
+        self.text_vocab_count = text_vocab_count
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
         self.layers: List[EncoderLayer] = nn.ModuleList([
@@ -138,6 +139,7 @@ class DalleBartEncoder(nn.Module):
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
         pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
+        text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
@@ -1,5 +1,5 @@
 import torch
-from torch import FloatTensor
+from torch import FloatTensor, LongTensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 torch.set_grad_enabled(False)
@@ -160,12 +160,14 @@ class Decoder(Module):
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()
-        m, n = 2 ** 14, 2 ** 8
-        self.embedding = Embedding(m, n)
-        self.post_quant_conv = Conv2d(n, n, 1)
+        vocab_count, embed_count = 2 ** 14, 2 ** 8
+        self.vocab_count = vocab_count
+        self.embedding = Embedding(vocab_count, embed_count)
+        self.post_quant_conv = Conv2d(embed_count, embed_count, 1)
         self.decoder = Decoder()

-    def forward(self, z: FloatTensor) -> FloatTensor:
+    def forward(self, z: LongTensor) -> FloatTensor:
+        z = z.clamp(0, self.vocab_count - 1)
         z = self.embedding.forward(z)
         z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
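Besides renaming m and n to vocab_count and embed_count, this hunk corrects the annotation on forward: the detokenizer receives integer token ids (hence the new LongTensor import) and now clamps them into the codebook before the lookup. A rough sketch of the shape flow through the lines above, using the 16x16 grid and 256-dim codebook that the view call implies; the random ids are placeholders:

import torch
from torch.nn import Embedding

vocab_count, embed_count = 2 ** 14, 2 ** 8
embedding = Embedding(vocab_count, embed_count)

z = torch.randint(0, 2 ** 15, (1, 256))            # 256 token ids for one image, some out of range
z = z.clamp(0, vocab_count - 1)                    # keep every id inside the codebook
z = embedding.forward(z)                           # (1, 256, 256)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))   # (1, 16, 16, 256)
z = z.permute(0, 3, 1, 2).contiguous()             # (1, 256, 16, 16), channels-first for post_quant_conv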
setup.py

@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     # long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.3.3',
+    version='0.3.4',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',