refactor: extract model to module and add image sorting script
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
85
sort.py
85
sort.py
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from model import CropLowerRightTriangle, GarageDoorCNN
|
||||
|
||||
def sort_images():
|
||||
# --- Configuration ---
|
||||
MODEL_PATH = 'garage_door_cnn.pth'
|
||||
SOURCE_DIR = 'data/hourly_photos/'
|
||||
DEST_DIR = 'data/sorted/open/'
|
||||
|
||||
# These must match the parameters used during training
|
||||
TRIANGLE_CROP_WIDTH = 556
|
||||
TRIANGLE_CROP_HEIGHT = 1184
|
||||
RESIZE_DIM = 64
|
||||
|
||||
# The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
|
||||
CLASS_NAMES = ['closed', 'open']
|
||||
TARGET_CLASS = 'open'
|
||||
TARGET_CLASS_IDX = CLASS_NAMES.index(TARGET_CLASS)
|
||||
|
||||
# --- Setup ---
|
||||
# Create destination directory if it doesn't exist
|
||||
os.makedirs(DEST_DIR, exist_ok=True)
|
||||
|
||||
# Set up device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load model
|
||||
model = GarageDoorCNN(resize_dim=RESIZE_DIM)
|
||||
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Define image transforms
|
||||
data_transform = transforms.Compose([
|
||||
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
||||
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# --- Process Images ---
|
||||
print(f"Scanning images in {SOURCE_DIR}...")
|
||||
with torch.no_grad():
|
||||
for filename in os.listdir(SOURCE_DIR):
|
||||
file_path = os.path.join(SOURCE_DIR, filename)
|
||||
if os.path.isfile(file_path):
|
||||
try:
|
||||
image = Image.open(file_path).convert('RGB')
|
||||
|
||||
# Apply transformations
|
||||
input_tensor = data_transform(image)
|
||||
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
input_batch = input_batch.to(device)
|
||||
|
||||
# Get model output
|
||||
output = model(input_batch)
|
||||
|
||||
# Get probabilities and prediction
|
||||
probabilities = F.softmax(output, dim=1)
|
||||
confidence, pred_idx = torch.max(probabilities, 1)
|
||||
|
||||
if pred_idx.item() == TARGET_CLASS_IDX:
|
||||
print(f"Found 'open' image: {file_path} with confidence: {confidence.item():.4f}")
|
||||
# Copy file
|
||||
shutil.copy(file_path, os.path.join(DEST_DIR, filename))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Could not process file {file_path}: {e}")
|
||||
|
||||
print("Sorting complete.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not os.path.exists('garage_door_cnn.pth'):
|
||||
print("Error: Model file 'garage_door_cnn.pth' not found. Please run train.py first.")
|
||||
elif not os.path.isdir('data/hourly_photos'):
|
||||
print("Error: Source directory 'data/hourly_photos' not found.")
|
||||
else:
|
||||
sort_images()
|
||||
|
Reference in New Issue
Block a user