Add script to run 100 tests
This commit is contained in:
parent
d29ecbd933
commit
e70aa360f3
75
test100.py
Normal file
75
test100.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
from min_dalle import MinDalle
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
NUM_TESTS = 50
|
||||||
|
|
||||||
|
def generate_image(
|
||||||
|
is_mega: bool,
|
||||||
|
text: str,
|
||||||
|
seed: int,
|
||||||
|
grid_size: int,
|
||||||
|
top_k: int,
|
||||||
|
image_path: str,
|
||||||
|
models_root: str,
|
||||||
|
fp16: bool,
|
||||||
|
):
|
||||||
|
model = MinDalle(
|
||||||
|
is_mega=is_mega,
|
||||||
|
models_root=models_root,
|
||||||
|
is_reusable=False,
|
||||||
|
is_verbose=True,
|
||||||
|
dtype=torch.float16 if fp16 else torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
image = model.generate_image(
|
||||||
|
text,
|
||||||
|
seed,
|
||||||
|
grid_size,
|
||||||
|
top_k=top_k,
|
||||||
|
is_verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_dalle():
|
||||||
|
generate_image(
|
||||||
|
is_mega=True,
|
||||||
|
text='rich ducks playing poker',
|
||||||
|
seed=0,
|
||||||
|
grid_size=3,
|
||||||
|
top_k=256,
|
||||||
|
image_path='generated',
|
||||||
|
models_root='pretrained',
|
||||||
|
fp16=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
times = []
|
||||||
|
|
||||||
|
print('Disregarding first two tests...')
|
||||||
|
run_dalle()
|
||||||
|
run_dalle()
|
||||||
|
|
||||||
|
print('Running dalle', NUM_TESTS, 'times...')
|
||||||
|
|
||||||
|
for i in range(NUM_TESTS):
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
print()
|
||||||
|
print('Running test', i+1, '/', NUM_TESTS, '...')
|
||||||
|
|
||||||
|
run_dalle()
|
||||||
|
|
||||||
|
duration = time.time() - start
|
||||||
|
print(' Completed in', duration, 's.')
|
||||||
|
times.append(duration)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print()
|
||||||
|
print('Run times:')
|
||||||
|
for t in times:
|
||||||
|
print(t)
|
||||||
|
|
||||||
|
average = sum(times) / NUM_TESTS
|
||||||
|
print()
|
||||||
|
print('Average:', average, '| Max:', max(times), '| Min:', min(times))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user