Add script to run 100 tests
This commit is contained in:
parent
d29ecbd933
commit
e70aa360f3
75
test100.py
Normal file
75
test100.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
from min_dalle import MinDalle
|
||||
import torch
|
||||
import time
|
||||
|
||||
NUM_TESTS = 50
|
||||
|
||||
def generate_image(
|
||||
is_mega: bool,
|
||||
text: str,
|
||||
seed: int,
|
||||
grid_size: int,
|
||||
top_k: int,
|
||||
image_path: str,
|
||||
models_root: str,
|
||||
fp16: bool,
|
||||
):
|
||||
model = MinDalle(
|
||||
is_mega=is_mega,
|
||||
models_root=models_root,
|
||||
is_reusable=False,
|
||||
is_verbose=True,
|
||||
dtype=torch.float16 if fp16 else torch.float32
|
||||
)
|
||||
|
||||
image = model.generate_image(
|
||||
text,
|
||||
seed,
|
||||
grid_size,
|
||||
top_k=top_k,
|
||||
is_verbose=True
|
||||
)
|
||||
|
||||
def run_dalle():
|
||||
generate_image(
|
||||
is_mega=True,
|
||||
text='rich ducks playing poker',
|
||||
seed=0,
|
||||
grid_size=3,
|
||||
top_k=256,
|
||||
image_path='generated',
|
||||
models_root='pretrained',
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
times = []
|
||||
|
||||
print('Disregarding first two tests...')
|
||||
run_dalle()
|
||||
run_dalle()
|
||||
|
||||
print('Running dalle', NUM_TESTS, 'times...')
|
||||
|
||||
for i in range(NUM_TESTS):
|
||||
start = time.time()
|
||||
|
||||
print()
|
||||
print('Running test', i+1, '/', NUM_TESTS, '...')
|
||||
|
||||
run_dalle()
|
||||
|
||||
duration = time.time() - start
|
||||
print(' Completed in', duration, 's.')
|
||||
times.append(duration)
|
||||
|
||||
print()
|
||||
print()
|
||||
print('Run times:')
|
||||
for t in times:
|
||||
print(t)
|
||||
|
||||
average = sum(times) / NUM_TESTS
|
||||
print()
|
||||
print('Average:', average, '| Max:', max(times), '| Min:', min(times))
|
||||
|
Loading…
Reference in New Issue
Block a user