properly limit input to 64 tokens

This commit is contained in:
Brett Kuprel 2022-07-05 22:14:19 -04:00
parent 1a8c01047c
commit f071b31bdd
4 changed files with 6 additions and 4 deletions

4
cog.yaml vendored
View File

@ -1,12 +1,12 @@
build: build:
cuda: "11.0" cuda: "11.3"
gpu: true gpu: true
python_version: "3.8" python_version: "3.8"
system_packages: system_packages:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
- "libglib2.0-0" - "libglib2.0-0"
python_packages: python_packages:
- "min-dalle==0.2.28" - "min-dalle==0.2.29"
run: run:
- pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -172,6 +172,8 @@ class MinDalle:
assert(log2_mid_count in range(5)) assert(log2_mid_count in range(5))
if is_verbose: print("tokenizing text") if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
if len(tokens) > self.text_token_count:
tokens = tokens[:self.text_token_count]
if is_verbose: print("text tokens", tokens) if is_verbose: print("text tokens", tokens)
text_tokens = numpy.ones((2, 64), dtype=numpy.int32) text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]] text_tokens[0, :2] = [tokens[0], tokens[-1]]

View File

@ -11,7 +11,7 @@ class ReplicatePredictor(BasePredictor):
def predict( def predict(
self, self,
text: str = Input( text: str = Input(
description='Text', description='For long prompts, only the first 64 tokens will be used to generate the image.',
default='Dali painting of WALL·E' default='Dali painting of WALL·E'
), ),
intermediate_outputs: bool = Input( intermediate_outputs: bool = Input(

View File

@ -5,7 +5,7 @@ setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
long_description=(Path(__file__).parent / "README.rst").read_text(), long_description=(Path(__file__).parent / "README.rst").read_text(),
version='0.2.28', version='0.2.29',
author='Brett Kuprel', author='Brett Kuprel',
author_email='brkuprel@gmail.com', author_email='brkuprel@gmail.com',
url='https://github.com/kuprel/min-dalle', url='https://github.com/kuprel/min-dalle',