diff --git a/sort.py b/sort.py index c246068..3bc5097 100644 --- a/sort.py +++ b/sort.py @@ -4,6 +4,7 @@ from torchvision import transforms from PIL import Image import os import shutil +import random from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH, TRIANGLE_CROP_HEIGHT, RESIZE_DIM) @@ -45,11 +46,16 @@ def sort_images(): # --- Process Images --- print(f"Scanning images in {SOURCE_DIR}...") + + # Get and shuffle filenames + filenames = [f for f in os.listdir(SOURCE_DIR) if os.path.isfile(os.path.join(SOURCE_DIR, f))] + random.shuffle(filenames) + total_files = len(filenames) + with torch.no_grad(): - for filename in os.listdir(SOURCE_DIR): + for i, filename in enumerate(filenames): file_path = os.path.join(SOURCE_DIR, filename) - if os.path.isfile(file_path): - try: + try: image = Image.open(file_path).convert('RGB') # Apply transformations @@ -65,12 +71,13 @@ def sort_images(): confidence, pred_idx = torch.max(probabilities, 1) if pred_idx.item() == TARGET_CLASS_IDX and confidence.item() > CONFIDENCE_THRESHOLD: - print(f"Found 'open' image: {file_path} with confidence: {confidence.item():.4f}") + progress = f"({i + 1}/{total_files} - {(i + 1) / total_files * 100:.1f}%)" + print(f"{progress} Found 'open' image: {file_path} with confidence: {confidence.item():.4f}") # Copy file shutil.copy(file_path, os.path.join(DEST_DIR, filename)) - except Exception as e: - print(f"Could not process file {file_path}: {e}") + except Exception as e: + print(f"Could not process file {file_path}: {e}") print("Sorting complete.")