Files
doormind/sort.py
2025-07-31 18:44:05 -06:00

91 lines
3.4 KiB
Python

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import shutil
import random
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
def _load_model(model_path, device):
    """Build a ``GarageDoorCNN``, load its weights from ``model_path``, move it to ``device``, set eval mode."""
    model = GarageDoorCNN(resize_dim=RESIZE_DIM)
    # map_location remaps stored tensors onto `device` (e.g. a GPU-saved checkpoint on a CPU-only box).
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def _build_transform():
    """Return the preprocessing pipeline: lower-right triangle crop, square resize, tensor, normalize."""
    # NOTE(review): presumably this must mirror the transforms used in train.py -- confirm they stay in sync.
    return transforms.Compose([
        CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
        transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
        transforms.ToTensor(),
        # ImageNet channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def _classify(model, data_transform, file_path, device):
    """Return ``(confidence, class_index)`` of the top softmax prediction for the image at ``file_path``."""
    # The context manager releases the file handle; convert() returns an in-memory copy,
    # so `image` stays usable after the source file is closed.
    with Image.open(file_path) as img:
        image = img.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    input_batch = data_transform(image).unsqueeze(0).to(device)
    probabilities = F.softmax(model(input_batch), dim=1)
    confidence, pred_idx = torch.max(probabilities, 1)
    return confidence.item(), pred_idx.item()


def sort_images(model_path='garage_door_cnn.pth',
                source_dir='data/hourly_photos/',
                dest_dir='data/sorted/open/',
                confidence_threshold=0.80,
                target_class='open'):
    """Copy images that the model classifies as ``target_class`` from ``source_dir`` into ``dest_dir``.

    Every regular file directly inside ``source_dir`` (non-recursive) is classified,
    in random order. A file is copied when the top prediction is ``target_class``
    and its softmax probability is strictly greater than ``confidence_threshold``.
    Files that fail to open, classify, or copy are reported and skipped.

    Args:
        model_path: Path to the saved ``GarageDoorCNN`` state dict.
        source_dir: Directory scanned for images.
        dest_dir: Directory matching images are copied into; created if missing.
        confidence_threshold: Exclusive lower bound on the softmax probability required to copy.
        target_class: Class name to select; must be one of ``'closed'`` / ``'open'``.

    Raises:
        ValueError: If ``target_class`` is not a known class name.
    """
    # The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
    class_names = ['closed', 'open']
    target_idx = class_names.index(target_class)

    # Create destination directory if it doesn't exist
    os.makedirs(dest_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(model_path, device)
    data_transform = _build_transform()

    print(f"Scanning images in {source_dir}...")
    filenames = [f for f in os.listdir(source_dir)
                 if os.path.isfile(os.path.join(source_dir, f))]
    random.shuffle(filenames)
    total_files = len(filenames)

    with torch.no_grad():
        for count, filename in enumerate(filenames, start=1):
            file_path = os.path.join(source_dir, filename)
            try:
                confidence, pred_idx = _classify(model, data_transform, file_path, device)
                if pred_idx == target_idx and confidence > confidence_threshold:
                    progress = f"({count}/{total_files} - {count / total_files * 100:.1f}%)"
                    print(f"{progress} Found '{target_class}' image: {file_path} "
                          f"with confidence: {confidence:.4f}")
                    shutil.copy(file_path, os.path.join(dest_dir, filename))
            except Exception as e:
                # Deliberate best-effort batch: report the bad file and keep going.
                print(f"Could not process file {file_path}: {e}")

    print("Sorting complete.")
if __name__ == '__main__':
    import sys

    # Fail fast on missing inputs. sys.exit(msg) prints msg to stderr and exits with
    # status 1, so shell scripts or schedulers invoking this script can detect the failure.
    if not os.path.exists('garage_door_cnn.pth'):
        sys.exit("Error: Model file 'garage_door_cnn.pth' not found. Please run train.py first.")
    if not os.path.isdir('data/hourly_photos'):
        sys.exit("Error: Source directory 'data/hourly_photos' not found.")
    sort_images()