91 lines
3.4 KiB
Python
91 lines
3.4 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
import os
|
|
import shutil
|
|
import random
|
|
|
|
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
|
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
|
|
|
def _load_model(model_path, device):
    """Build a GarageDoorCNN, load its state dict from *model_path*, and return it in eval mode on *device*."""
    model = GarageDoorCNN(resize_dim=RESIZE_DIM)
    # map_location lets weights saved on one device be loaded onto whichever device is available here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # inference mode for layers such as dropout/batch-norm, if the model has any
    return model


def _build_transform():
    """Return the preprocessing pipeline: triangle crop, resize, tensor conversion, normalization."""
    return transforms.Compose([
        CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH,
                               triangle_height=TRIANGLE_CROP_HEIGHT),
        transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std; presumably must match the training
        # pipeline in train.py — TODO confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def _predict(model, image_path, transform, device):
    """Classify one image file and return ``(confidence, class_index)`` for the top softmax class."""
    # The context manager closes the underlying file handle. convert() returns a
    # new, fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    input_batch = transform(image).unsqueeze(0).to(device)  # add a batch dimension of size 1
    probabilities = F.softmax(model(input_batch), dim=1)
    confidence, pred_idx = torch.max(probabilities, 1)
    return confidence.item(), pred_idx.item()


def sort_images(model_path='garage_door_cnn.pth',
                source_dir='data/hourly_photos/',
                dest_dir='data/sorted/open/',
                confidence_threshold=0.80):
    """Copy images that the model confidently classifies as 'open' into *dest_dir*.

    Every regular file directly inside *source_dir* is visited in random order
    and classified with the GarageDoorCNN loaded from *model_path*. A file is
    copied (same filename) into *dest_dir* when the top prediction is the
    target class and its softmax probability is strictly greater than
    *confidence_threshold*. Any file whose processing raises an exception is
    reported and skipped.

    Args:
        model_path: Path to the saved model state dict.
        source_dir: Directory scanned (non-recursively) for images.
        dest_dir: Destination directory; created if it does not exist.
        confidence_threshold: Exclusive lower bound on the softmax probability
            required to copy a file.
    """
    # The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
    class_names = ['closed', 'open']
    target_class = 'open'
    target_class_idx = class_names.index(target_class)

    os.makedirs(dest_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(model_path, device)
    data_transform = _build_transform()

    print(f"Scanning images in {source_dir}...")
    filenames = [f for f in os.listdir(source_dir)
                 if os.path.isfile(os.path.join(source_dir, f))]
    # Random visiting order; presumably so an interrupted run still samples
    # across the whole directory — confirm intent.
    random.shuffle(filenames)
    total_files = len(filenames)

    with torch.no_grad():
        for i, filename in enumerate(filenames, start=1):
            file_path = os.path.join(source_dir, filename)
            try:
                confidence, pred_idx = _predict(model, file_path, data_transform, device)
                if pred_idx == target_class_idx and confidence > confidence_threshold:
                    progress = f"({i}/{total_files} - {i / total_files * 100:.1f}%)"
                    print(f"{progress} Found '{target_class}' image: {file_path} "
                          f"with confidence: {confidence:.4f}")
                    shutil.copy(file_path, os.path.join(dest_dir, filename))
            except Exception as e:
                # Best-effort: the listing only filters on isfile(), so non-image or
                # corrupt files reach this point; report and move on.
                print(f"Could not process file {file_path}: {e}")

    print("Sorting complete.")
|
|
|
|
if __name__ == '__main__':
    import sys

    # Check prerequisites up front, so a missing input produces a clear message
    # and a non-zero exit status (detectable by cron/shell callers) rather than
    # a traceback or a silent "success".
    if not os.path.isfile('garage_door_cnn.pth'):
        print("Error: Model file 'garage_door_cnn.pth' not found. Please run train.py first.",
              file=sys.stderr)
        sys.exit(1)
    if not os.path.isdir('data/hourly_photos'):
        print("Error: Source directory 'data/hourly_photos' not found.", file=sys.stderr)
        sys.exit(1)
    sort_images()
|