min-dalle-test/test100.py

76 lines
1.4 KiB
Python
Raw Normal View History

2022-07-19 20:50:37 +00:00
from min_dalle import MinDalle
import torch
import time
# How many timed benchmark iterations to run. The warm-up calls made in the
# ``__main__`` section are executed in addition to these and are not timed.
NUM_TESTS: int = 50
def generate_image(
    is_mega: bool,
    text: str,
    seed: int,
    grid_size: int,
    top_k: int,
    image_path: str,
    models_root: str,
    fp16: bool,
):
    """Construct a fresh ``MinDalle`` model and generate an image for ``text``.

    A new model object is built on every call, so whenever this function is
    timed, the construction cost is included in the measurement.

    Args:
        is_mega: Passed through as ``MinDalle(is_mega=...)``.
        text: Prompt passed to ``MinDalle.generate_image``.
        seed: Seed passed to ``MinDalle.generate_image``.
        grid_size: Grid size passed to ``MinDalle.generate_image``.
        top_k: Passed as ``top_k=`` to ``MinDalle.generate_image``.
        image_path: Accepted for interface compatibility but not used; the
            image is NOT written to disk (doing so would add file I/O to any
            timed region).
        models_root: Passed through as ``MinDalle(models_root=...)``.
        fp16: When true the model is created with ``torch.float16``,
            otherwise with ``torch.float32``.

    Returns:
        The value returned by ``MinDalle.generate_image``. It was previously
        discarded; returning it lets callers inspect or save the result.
    """
    model = MinDalle(
        is_mega=is_mega,
        models_root=models_root,
        is_reusable=False,
        is_verbose=True,
        dtype=torch.float16 if fp16 else torch.float32,
    )
    image = model.generate_image(
        text,
        seed,
        grid_size,
        top_k=top_k,
        is_verbose=True,
    )
    return image
def run_dalle():
    """Run one benchmark iteration using a fixed prompt and configuration.

    The configuration is the mega model with ``fp16=True``, seed 0,
    ``grid_size`` 3, ``top_k`` 256 and models read from ``'pretrained'``.
    The call's result is discarded.
    """
    settings = dict(
        is_mega=True,
        text='rich ducks playing poker',
        seed=0,
        grid_size=3,
        top_k=256,
        image_path='generated',
        models_root='pretrained',
        fp16=True,
    )
    generate_image(**settings)
def time_call(fn):
    """Call ``fn()`` once and return the elapsed time in seconds.

    Uses ``time.perf_counter``, a monotonic high-resolution clock meant for
    measuring durations. ``time.time`` follows the system clock and can jump
    if that clock is adjusted mid-run.
    """
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def summarize_times(times):
    """Return ``(average, max, min)`` of ``times``, or ``None`` if empty.

    The average divides by ``len(times)``, the number of samples actually
    collected, and the empty case is handled explicitly. This avoids a
    ZeroDivisionError, and a ValueError from ``max``/``min``, when no timed
    runs happened.
    """
    if not times:
        return None
    return sum(times) / len(times), max(times), min(times)


if __name__ == '__main__':
    times = []
    # Two untimed warm-up calls; their durations are not recorded.
    print('Disregarding first two tests...')
    run_dalle()
    run_dalle()
    print('Running dalle', NUM_TESTS, 'times...')
    for i in range(NUM_TESTS):
        print()
        print('Running test', i + 1, '/', NUM_TESTS, '...')
        # Only the run itself is timed, not the progress prints above.
        duration = time_call(run_dalle)
        print(' Completed in', duration, 's.')
        times.append(duration)
    print()
    print()
    print('Run times:')
    for t in times:
        print(t)
    print()
    summary = summarize_times(times)
    if summary is None:
        print('No timed runs.')
    else:
        average, longest, shortest = summary
        print('Average:', average, '| Max:', longest, '| Min:', shortest)