clamp in place

main
Brett Kuprel 2 years ago
parent 247c41be17
commit 1ffdef9a56
  1. 2
      cog.yaml
  2. 2
      min_dalle/models/dalle_bart_decoder.py
  3. 2
      min_dalle/models/dalle_bart_encoder.py
  4. 2
      min_dalle/models/vqgan_detokenizer.py
  5. 2
      replicate_predictor.py
  6. 4
      setup.py

2
cog.yaml vendored

@ -6,7 +6,7 @@ build:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "min-dalle==0.3.7"
- "min-dalle==0.3.11"
run:
- pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html

@ -151,7 +151,7 @@ class DalleBartDecoder(nn.Module):
image_count = encoder_state.shape[0] // 2
token_index_batched = token_index[[0] * image_count * 2]
prev_tokens = prev_tokens[list(range(image_count)) * 2]
prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
prev_tokens.clamp_(0, self.image_vocab_count)
decoder_state = self.embed_tokens.forward(prev_tokens)
decoder_state += self.embed_positions.forward(token_index_batched)
decoder_state = self.layernorm_embedding.forward(decoder_state)

@ -138,7 +138,7 @@ class DalleBartEncoder(nn.Module):
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
text_tokens.clamp_(0, self.text_vocab_count - 1)
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)

@ -166,7 +166,7 @@ class VQGanDetokenizer(Module):
self.decoder = Decoder()
def forward(self, z: LongTensor) -> FloatTensor:
z = z.clamp(0, self.vocab_count - 1)
z.clamp_(0, self.vocab_count - 1)
z = self.embedding.forward(z)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous()

@ -22,7 +22,7 @@ class ReplicatePredictor(BasePredictor):
default=True
),
grid_size: int = Input(
description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 35 seconds',
description='Size of the image grid. 5x5 takes around 16 seconds, 8x8 takes around 36 seconds',
ge=1,
le=8,
default=4

@ -1,11 +1,11 @@
import setuptools
from pathlib import Path
# from pathlib import Path
setuptools.setup(
name='min-dalle',
description = 'min(DALL·E)',
# long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.3.7',
version='0.3.11',
author='Brett Kuprel',
author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle',

Loading…
Cancel
Save