diff --git a/test100.py b/test100.py new file mode 100644 index 0000000..a119efe --- /dev/null +++ b/test100.py @@ -0,0 +1,75 @@ +from min_dalle import MinDalle +import torch +import time + +NUM_TESTS = 50 + +def generate_image( + is_mega: bool, + text: str, + seed: int, + grid_size: int, + top_k: int, + image_path: str, + models_root: str, + fp16: bool, +): + model = MinDalle( + is_mega=is_mega, + models_root=models_root, + is_reusable=False, + is_verbose=True, + dtype=torch.float16 if fp16 else torch.float32 + ) + + image = model.generate_image( + text, + seed, + grid_size, + top_k=top_k, + is_verbose=True + ) + +def run_dalle(): + generate_image( + is_mega=True, + text='rich ducks playing poker', + seed=0, + grid_size=3, + top_k=256, + image_path='generated', + models_root='pretrained', + fp16=True, + ) + +if __name__ == '__main__': + times = [] + + print('Disregarding first two tests...') + run_dalle() + run_dalle() + + print('Running dalle', NUM_TESTS, 'times...') + + for i in range(NUM_TESTS): + start = time.time() + + print() + print('Running test', i+1, '/', NUM_TESTS, '...') + + run_dalle() + + duration = time.time() - start + print(' Completed in', duration, 's.') + times.append(duration) + + print() + print() + print('Run times:') + for t in times: + print(t) + + average = sum(times) / NUM_TESTS + print() + print('Average:', average, '| Max:', max(times), '| Min:', min(times)) +