clamp in place
This commit is contained in:
parent
247c41be17
commit
1ffdef9a56
2
cog.yaml
vendored
2
cog.yaml
vendored
|
@ -6,7 +6,7 @@ build:
|
|||
- "libgl1-mesa-glx"
|
||||
- "libglib2.0-0"
|
||||
python_packages:
|
||||
- "min-dalle==0.3.7"
|
||||
- "min-dalle==0.3.11"
|
||||
run:
|
||||
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
|
|
|
@ -151,7 +151,7 @@ class DalleBartDecoder(nn.Module):
|
|||
image_count = encoder_state.shape[0] // 2
|
||||
token_index_batched = token_index[[0] * image_count * 2]
|
||||
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
||||
prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
|
||||
prev_tokens.clamp_(0, self.image_vocab_count)
|
||||
decoder_state = self.embed_tokens.forward(prev_tokens)
|
||||
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||
|
|
|
@ -138,7 +138,7 @@ class DalleBartEncoder(nn.Module):
|
|||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
|
||||
text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
|
||||
text_tokens.clamp_(0, self.text_vocab_count - 1)
|
||||
encoder_state = (
|
||||
self.embed_tokens.forward(text_tokens) +
|
||||
self.embed_positions.forward(pose_tokens)
|
||||
|
|
|
@ -166,7 +166,7 @@ class VQGanDetokenizer(Module):
|
|||
self.decoder = Decoder()
|
||||
|
||||
def forward(self, z: LongTensor) -> FloatTensor:
|
||||
z = z.clamp(0, self.vocab_count - 1)
|
||||
z.clamp_(0, self.vocab_count - 1)
|
||||
z = self.embedding.forward(z)
|
||||
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
|
||||
z = z.permute(0, 3, 1, 2).contiguous()
|
||||
|
|
|
@ -22,7 +22,7 @@ class ReplicatePredictor(BasePredictor):
|
|||
default=True
|
||||
),
|
||||
grid_size: int = Input(
|
||||
description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 35 seconds',
|
||||
description='Size of the image grid. 5x5 takes around 16 seconds, 8x8 takes around 36 seconds',
|
||||
ge=1,
|
||||
le=8,
|
||||
default=4
|
||||
|
|
4
setup.py
4
setup.py
|
@ -1,11 +1,11 @@
|
|||
import setuptools
|
||||
from pathlib import Path
|
||||
# from pathlib import Path
|
||||
|
||||
setuptools.setup(
|
||||
name='min-dalle',
|
||||
description = 'min(DALL·E)',
|
||||
# long_description=(Path(__file__).parent / "README.rst").read_text(),
|
||||
version='0.3.7',
|
||||
version='0.3.11',
|
||||
author='Brett Kuprel',
|
||||
author_email='brkuprel@gmail.com',
|
||||
url='https://github.com/kuprel/min-dalle',
|
||||
|
|
Loading…
Reference in New Issue
Block a user