"""Copy photos that a trained GarageDoorCNN classifies as the target class.

Each regular file in the source directory is run through the same preprocessing
used at training time. Files predicted as ``TARGET_CLASS`` are copied into the
destination directory. Run ``train.py`` first to produce the model weights.
"""

import os
import shutil

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model import CropLowerRightTriangle, GarageDoorCNN

# --- Default configuration (shared by sort_images and the __main__ checks) ---
MODEL_PATH = 'garage_door_cnn.pth'
SOURCE_DIR = 'data/hourly_photos/'
DEST_DIR = 'data/sorted/open/'

# These must match the parameters used during training.
TRIANGLE_CROP_WIDTH = 556
TRIANGLE_CROP_HEIGHT = 1184
RESIZE_DIM = 64

# The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
CLASS_NAMES = ['closed', 'open']
TARGET_CLASS = 'open'


def _build_transform():
    """Return the preprocessing pipeline; it must mirror the training-time one."""
    return transforms.Compose([
        CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH,
                               triangle_height=TRIANGLE_CROP_HEIGHT),
        transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
        transforms.ToTensor(),
        # ImageNet channel statistics, as used when the model was trained.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _load_model(model_path, device):
    """Build a GarageDoorCNN, load weights from *model_path*, move it to *device*, set eval mode."""
    model = GarageDoorCNN(resize_dim=RESIZE_DIM)
    # NOTE(review): torch.load unpickles the checkpoint, so only load files you
    # trust. On PyTorch >= 1.13 consider passing weights_only=True, because
    # this file holds a plain state_dict.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def sort_images(model_path=MODEL_PATH, source_dir=SOURCE_DIR, dest_dir=DEST_DIR):
    """Classify every file in *source_dir* and copy TARGET_CLASS hits to *dest_dir*.

    Args:
        model_path: Path to the saved GarageDoorCNN state_dict.
        source_dir: Directory scanned (non-recursively) for images.
        dest_dir: Directory that receives copies of matching images; it is
            created if missing. Existing files with the same name are overwritten.

    Files that cannot be opened or classified are reported and skipped.
    """
    target_class_idx = CLASS_NAMES.index(TARGET_CLASS)

    # --- Setup ---
    os.makedirs(dest_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(model_path, device)
    data_transform = _build_transform()

    # --- Process Images ---
    print(f"Scanning images in {source_dir}...")
    with torch.no_grad():
        # sorted() gives a deterministic processing order across runs.
        for filename in sorted(os.listdir(source_dir)):
            file_path = os.path.join(source_dir, filename)
            if not os.path.isfile(file_path):
                continue
            try:
                # Image.open is lazy and keeps the file open; the context
                # manager releases the handle as soon as the pixels are loaded.
                with Image.open(file_path) as img:
                    image = img.convert('RGB')

                # The model expects a mini-batch, so add a leading batch dimension.
                input_batch = data_transform(image).unsqueeze(0).to(device)

                probabilities = F.softmax(model(input_batch), dim=1)
                confidence, pred_idx = torch.max(probabilities, 1)

                if pred_idx.item() == target_class_idx:
                    print(f"Found '{TARGET_CLASS}' image: {file_path} "
                          f"with confidence: {confidence.item():.4f}")
                    shutil.copy(file_path, os.path.join(dest_dir, filename))
            except Exception as e:
                # Best-effort per file: a corrupt or non-image file is reported
                # and skipped rather than aborting the whole run.
                print(f"Could not process file {file_path}: {e}")

    print("Sorting complete.")


if __name__ == '__main__':
    # Pre-flight checks use the same constants as sort_images so they cannot drift.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found. Please run train.py first.")
    elif not os.path.isdir(SOURCE_DIR):
        print(f"Error: Source directory '{SOURCE_DIR}' not found.")
    else:
        sort_images()