"""Benchmark end-to-end min-dalle image generation.

Runs two untimed warm-up generations, then NUM_TESTS timed generations,
and prints every run time followed by the average, max and min.
"""

import time

import torch
from min_dalle import MinDalle

NUM_TESTS = 50


def generate_image(
    is_mega: bool,
    text: str,
    seed: int,
    grid_size: int,
    top_k: int,
    image_path: str,
    models_root: str,
    fp16: bool,
):
    """Build a fresh MinDalle model and generate one image grid from *text*.

    A new model is constructed on every call, so timing a call measures
    model construction as well as generation.

    Args:
        is_mega: Forwarded to ``MinDalle(is_mega=...)``.
        text: Prompt passed to ``MinDalle.generate_image``.
        seed: Seed passed to ``MinDalle.generate_image``.
        grid_size: Grid size passed to ``MinDalle.generate_image``.
        top_k: Forwarded as ``top_k`` to ``MinDalle.generate_image``.
        image_path: Accepted for signature compatibility but not used;
            nothing is written to disk by this function.
        models_root: Forwarded to ``MinDalle(models_root=...)``.
        fp16: Build the model with ``torch.float16`` when true, otherwise
            ``torch.float32``.

    Returns:
        Whatever ``MinDalle.generate_image`` returns, so callers may save or
        inspect the result if they wish.
    """
    model = MinDalle(
        is_mega=is_mega,
        models_root=models_root,
        is_reusable=False,
        is_verbose=True,
        dtype=torch.float16 if fp16 else torch.float32,
    )
    return model.generate_image(
        text,
        seed,
        grid_size,
        top_k=top_k,
        is_verbose=True,
    )


def run_dalle():
    """Generate the fixed benchmark prompt once (mega model, fp16, 3x3 grid)."""
    generate_image(
        is_mega=True,
        text='rich ducks playing poker',
        seed=0,
        grid_size=3,
        top_k=256,
        image_path='generated',
        models_root='pretrained',
        fp16=True,
    )


def _report(times):
    """Print each run time, then the average, max and min of *times*."""
    print()
    print()
    print('Run times:')
    for t in times:
        print(t)
    if not times:
        # Nothing to summarise; avoids ZeroDivisionError and max()/min()
        # of an empty sequence.
        return
    average = sum(times) / len(times)
    print()
    print('Average:', average, '| Max:', max(times), '| Min:', min(times))


def main():
    """Warm up twice, time NUM_TESTS runs of run_dalle, and report results."""
    times = []

    # Two untimed warm-up runs so one-time startup costs (presumably
    # library / device initialisation) do not skew the measurements.
    print('Disregarding first two tests...')
    run_dalle()
    run_dalle()

    print('Running dalle', NUM_TESTS, 'times...')
    for i in range(NUM_TESTS):
        print()
        print('Running test', i + 1, '/', NUM_TESTS, '...')
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows the wall clock and can jump if it is adjusted.
        # The clock starts after the progress prints so only the run is timed.
        start = time.perf_counter()
        run_dalle()
        duration = time.perf_counter() - start
        print(' Completed in', duration, 's.')
        times.append(duration)

    _report(times)


if __name__ == '__main__':
    main()