You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
1.4 KiB
75 lines
1.4 KiB
import time

import torch
from min_dalle import MinDalle
|
|
|
# How many timed benchmark iterations to run; the warm-up calls made
# before the timed loop are not included in this count.
NUM_TESTS: int = 50
|
|
|
def generate_image(
    is_mega: bool,
    text: str,
    seed: int,
    grid_size: int,
    top_k: int,
    image_path: str,
    models_root: str,
    fp16: bool,
):
    """Construct a fresh MinDalle model and generate one image grid from *text*.

    A new ``MinDalle`` instance (with ``is_reusable=False``) is created on every
    call, so each call covers model construction as well as generation.

    Args:
        is_mega: Forwarded to ``MinDalle(is_mega=...)``.
        text: Prompt passed to ``MinDalle.generate_image``.
        seed: Seed passed to ``MinDalle.generate_image``.
        grid_size: Grid size passed to ``MinDalle.generate_image``.
        top_k: ``top_k`` passed to ``MinDalle.generate_image``.
        image_path: Accepted but not used; the image is never written to disk.
            NOTE(review): presumably intentional so disk I/O stays out of the
            benchmark timings -- confirm before relying on it.
        models_root: Forwarded to ``MinDalle(models_root=...)``.
        fp16: When true the model uses ``torch.float16``, otherwise
            ``torch.float32``.

    Returns:
        The value returned by ``MinDalle.generate_image`` (previously this was
        bound to an unused local and discarded).
    """
    dtype = torch.float16 if fp16 else torch.float32
    model = MinDalle(
        is_mega=is_mega,
        models_root=models_root,
        is_reusable=False,
        is_verbose=True,
        dtype=dtype,
    )

    image = model.generate_image(
        text,
        seed,
        grid_size,
        top_k=top_k,
        is_verbose=True,
    )
    return image
|
|
|
def run_dalle():
    """Run a single benchmark iteration using a fixed prompt and fixed settings."""
    # Every iteration uses identical arguments so the runs are comparable.
    settings = dict(
        is_mega=True,
        text='rich ducks playing poker',
        seed=0,
        grid_size=3,
        top_k=256,
        image_path='generated',
        models_root='pretrained',
        fp16=True,
    )
    generate_image(**settings)
|
|
|
if __name__ == '__main__':
    # Benchmark driver: a couple of untimed warm-up runs, then NUM_TESTS timed
    # runs of run_dalle(), then a per-run listing and summary statistics.
    times = []

    # NOTE(review): the warm-up runs are presumably there to absorb one-time
    # startup costs before timing begins -- confirm.
    print('Disregarding first two tests...')
    run_dalle()
    run_dalle()

    print('Running dalle', NUM_TESTS, 'times...')

    for i in range(NUM_TESTS):
        print()
        print('Running test', i + 1, '/', NUM_TESTS, '...')

        # Start the clock after the progress prints so only run_dalle() is
        # measured. perf_counter() is monotonic and high-resolution, unlike
        # time.time(), which can jump if the system clock is adjusted.
        start = time.perf_counter()
        run_dalle()
        duration = time.perf_counter() - start

        print(' Completed in', duration, 's.')
        times.append(duration)

    print()
    print()
    print('Run times:')
    for t in times:
        print(t)

    print()
    # Average over the runs actually recorded; guard the empty case, where
    # the division, max() and min() would otherwise raise.
    if times:
        average = sum(times) / len(times)
        print('Average:', average, '| Max:', max(times), '| Min:', min(times))
    else:
        print('No timed runs (NUM_TESTS <= 0).')
|
|
|
|