# Benchmark script: times repeated MinDalle text-to-image generation.
# (Original listing: 76 lines, 1.4 KiB, Python.)
from min_dalle import MinDalle
|
|
import torch
|
|
import time
|
|
|
|
# Number of timed runs performed after the two untimed warm-up runs.
NUM_TESTS: int = 50
|
|
|
|
def generate_image(
    is_mega: bool,
    text: str,
    seed: int,
    grid_size: int,
    top_k: int,
    image_path: str,
    models_root: str,
    fp16: bool,
):
    """Build a fresh MinDalle model and generate an image grid from *text*.

    A new ``MinDalle`` instance is constructed on every call (with
    ``is_reusable=False``), so each call pays the model-construction cost in
    addition to generation. Presumably that includes loading weights from
    *models_root*; confirm against the min_dalle version in use.

    Args:
        is_mega: Forwarded to ``MinDalle`` as ``is_mega``.
        text: Prompt passed to ``MinDalle.generate_image``.
        seed: Seed passed to ``MinDalle.generate_image``.
        grid_size: Grid size passed to ``MinDalle.generate_image``.
        top_k: Passed to ``MinDalle.generate_image`` as ``top_k``.
        image_path: Currently unused; the image is NOT written to disk.
            It is kept so existing callers do not break.
        models_root: Forwarded to ``MinDalle`` as ``models_root``.
        fp16: When true the model is built with ``torch.float16``,
            otherwise with ``torch.float32``.

    Returns:
        The value returned by ``MinDalle.generate_image``. It was previously
        discarded, so callers can now inspect or save the result.
    """
    model = MinDalle(
        is_mega=is_mega,
        models_root=models_root,
        is_reusable=False,
        is_verbose=True,
        dtype=torch.float16 if fp16 else torch.float32,
    )

    # text, seed and grid_size stay positional, exactly as in the original call.
    image = model.generate_image(
        text,
        seed,
        grid_size,
        top_k=top_k,
        is_verbose=True,
    )
    return image
|
|
|
|
def run_dalle():
    """Perform one benchmark iteration using the fixed settings below."""
    # Every benchmark iteration uses exactly these arguments.
    settings = dict(
        is_mega=True,
        text='rich ducks playing poker',
        seed=0,
        grid_size=3,
        top_k=256,
        image_path='generated',
        models_root='pretrained',
        fp16=True,
    )
    generate_image(**settings)
|
|
|
|
if __name__ == '__main__':
    # Duration in seconds of each timed run.
    times = []

    # CUDA kernels launch asynchronously, so synchronize around the timed
    # region when a GPU is present; otherwise pending work could be missed.
    use_cuda = torch.cuda.is_available()

    # Untimed warm-up runs. They presumably absorb one-off costs such as
    # device initialization; confirm for the target hardware.
    print('Disregarding first two tests...')
    run_dalle()
    run_dalle()

    print('Running dalle', NUM_TESTS, 'times...')

    for i in range(NUM_TESTS):
        print()
        print('Running test', i + 1, '/', NUM_TESTS, '...')

        if use_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted. Console output is
        # kept outside the measured interval.
        start = time.perf_counter()
        run_dalle()
        if use_cuda:
            torch.cuda.synchronize()
        duration = time.perf_counter() - start

        print(' Completed in', duration, 's.')
        times.append(duration)

    print()
    print()
    print('Run times:')
    for t in times:
        print(t)

    print()
    if times:
        # Average over the samples actually collected.
        average = sum(times) / len(times)
        print('Average:', average, '| Max:', max(times), '| Min:', min(times))
    else:
        print('No timed runs (NUM_TESTS <= 0); nothing to summarize.')
|
|
|