From 1ffdef9a565cf5198a00f648edb5aa12564f182b Mon Sep 17 00:00:00 2001
From: Brett Kuprel <brkuprel@gmail.com>
Date: Sun, 10 Jul 2022 08:07:54 -0400
Subject: [PATCH] clamp in place

---
 cog.yaml                               | 2 +-
 min_dalle/models/dalle_bart_decoder.py | 2 +-
 min_dalle/models/dalle_bart_encoder.py | 2 +-
 min_dalle/models/vqgan_detokenizer.py  | 2 +-
 replicate_predictor.py                 | 2 +-
 setup.py                               | 4 ++--
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/cog.yaml b/cog.yaml
index bed81dc..430c32a 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -6,7 +6,7 @@ build:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "min-dalle==0.3.7"
+    - "min-dalle==0.3.11"
   run:
     - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index 9e75e4f..1dcee20 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -151,7 +151,7 @@ class DalleBartDecoder(nn.Module):
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
-        prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
+        prev_tokens.clamp_(0, self.image_vocab_count)
         decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py
index cb93389..4ee9045 100644
--- a/min_dalle/models/dalle_bart_encoder.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -138,7 +138,7 @@ class DalleBartEncoder(nn.Module):
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
         pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
-        text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
+        text_tokens.clamp_(0, self.text_vocab_count - 1)
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py
index 9c9ee48..5eb7ca0 100644
--- a/min_dalle/models/vqgan_detokenizer.py
+++ b/min_dalle/models/vqgan_detokenizer.py
@@ -166,7 +166,7 @@ class VQGanDetokenizer(Module):
         self.decoder = Decoder()

     def forward(self, z: LongTensor) -> FloatTensor:
-        z = z.clamp(0, self.vocab_count - 1)
+        z.clamp_(0, self.vocab_count - 1)
         z = self.embedding.forward(z)
         z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
diff --git a/replicate_predictor.py b/replicate_predictor.py
index 39018ec..22d2c1b 100644
--- a/replicate_predictor.py
+++ b/replicate_predictor.py
@@ -22,7 +22,7 @@ class ReplicatePredictor(BasePredictor):
             default=True
         ),
         grid_size: int = Input(
-            description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 35 seconds',
+            description='Size of the image grid. 5x5 takes around 16 seconds, 8x8 takes around 36 seconds',
             ge=1,
             le=8,
             default=4
diff --git a/setup.py b/setup.py
index b408af3..e248c52 100644
--- a/setup.py
+++ b/setup.py
@@ -1,11 +1,11 @@
 import setuptools
-from pathlib import Path
+# from pathlib import Path

 setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     # long_description=(Path(__file__).parent / "README.rst").read_text(),
-    version='0.3.7',
+    version='0.3.11',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
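
Note (not part of the patch): the three model edits above replace out-of-place Tensor.clamp with in-place Tensor.clamp_, which skips allocating a temporary tensor for each clamp. A minimal sketch of the difference in PyTorch, with made-up token values:

    # Sketch only: contrasts out-of-place clamp with the in-place clamp_
    # that this patch switches to. The example values are illustrative.
    import torch

    tokens = torch.tensor([-3, 5, 999])

    clamped = tokens.clamp(0, 100)                  # allocates a new tensor
    print(clamped.data_ptr() == tokens.data_ptr())  # False: separate storage
    print(tokens.tolist())                          # [-3, 5, 999], unchanged

    tokens.clamp_(0, 100)                           # mutates tokens in place
    print(tokens.tolist())                          # [0, 5, 100]

The trade-off is that clamp_ modifies the tensor it is called on, so it is only appropriate where nothing downstream needs the unclamped values, as in the forward passes changed here.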