parent
d29ecbd933
commit
e70aa360f3
1 changed files with 75 additions and 0 deletions
@ -0,0 +1,75 @@ |
||||
from min_dalle import MinDalle |
||||
import torch |
||||
import time |
||||
|
||||
NUM_TESTS = 50 |
||||
|
||||
def generate_image( |
||||
is_mega: bool, |
||||
text: str, |
||||
seed: int, |
||||
grid_size: int, |
||||
top_k: int, |
||||
image_path: str, |
||||
models_root: str, |
||||
fp16: bool, |
||||
): |
||||
model = MinDalle( |
||||
is_mega=is_mega, |
||||
models_root=models_root, |
||||
is_reusable=False, |
||||
is_verbose=True, |
||||
dtype=torch.float16 if fp16 else torch.float32 |
||||
) |
||||
|
||||
image = model.generate_image( |
||||
text, |
||||
seed, |
||||
grid_size, |
||||
top_k=top_k, |
||||
is_verbose=True |
||||
) |
||||
|
||||
def run_dalle(): |
||||
generate_image( |
||||
is_mega=True, |
||||
text='rich ducks playing poker', |
||||
seed=0, |
||||
grid_size=3, |
||||
top_k=256, |
||||
image_path='generated', |
||||
models_root='pretrained', |
||||
fp16=True, |
||||
) |
||||
|
||||
if __name__ == '__main__': |
||||
times = [] |
||||
|
||||
print('Disregarding first two tests...') |
||||
run_dalle() |
||||
run_dalle() |
||||
|
||||
print('Running dalle', NUM_TESTS, 'times...') |
||||
|
||||
for i in range(NUM_TESTS): |
||||
start = time.time() |
||||
|
||||
print() |
||||
print('Running test', i+1, '/', NUM_TESTS, '...') |
||||
|
||||
run_dalle() |
||||
|
||||
duration = time.time() - start |
||||
print(' Completed in', duration, 's.') |
||||
times.append(duration) |
||||
|
||||
print() |
||||
print() |
||||
print('Run times:') |
||||
for t in times: |
||||
print(t) |
||||
|
||||
average = sum(times) / NUM_TESTS |
||||
print() |
||||
print('Average:', average, '| Max:', max(times), '| Min:', min(times)) |
||||
|
Loading…
Reference in new issue