Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
183 lines
7.0 KiB
Python
183 lines
7.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Dataset
|
|
from torchvision import datasets, transforms
|
|
from PIL import Image, ImageDraw
|
|
import os
|
|
import argparse
|
|
|
|
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
|
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
|
|
|
|
|
def check_crop():
    """Save a sample cropped image for debugging purposes.

    Takes the alphabetically first image file in ``data/labelled/open``,
    applies ``CropLowerRightTriangle`` (configured with
    ``TRIANGLE_CROP_WIDTH`` and ``TRIANGLE_CROP_HEIGHT``) and writes the
    result to ``cropped_debug_output.png`` in the current working directory.

    If the sample directory is missing or holds no image files, an error is
    printed and the function returns without writing anything.
    """
    SAMPLE_IMAGE_DIR = 'data/labelled/open'
    # Suffixes mirror torchvision's IMG_EXTENSIONS, i.e. what ImageFolder
    # would pick up during training. Filtering keeps stray files such as
    # ".DS_Store" or notes from being handed to PIL.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                        '.tif', '.tiff', '.webp')

    candidates = []
    if os.path.isdir(SAMPLE_IMAGE_DIR):
        # os.listdir order is arbitrary; sorting makes the chosen sample
        # stable from run to run.
        candidates = sorted(
            name for name in os.listdir(SAMPLE_IMAGE_DIR)
            if name.lower().endswith(IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(SAMPLE_IMAGE_DIR, name))
        )

    if not candidates:
        print(f"Error: Cannot find sample image in '{SAMPLE_IMAGE_DIR}'.")
        print("Please ensure the directory exists and contains images.")
        return

    sample_image_path = os.path.join(SAMPLE_IMAGE_DIR, candidates[0])
    print(f"Creating debug crop from image: {sample_image_path}")

    # Create the transform
    cropper = CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT)

    output_path = "cropped_debug_output.png"
    # PIL opens files lazily; crop and save while the file is still open,
    # then let the context manager release the handle.
    with Image.open(sample_image_path) as image:
        cropped_image = cropper(image)
        cropped_image.save(output_path)

    print(f"Debug image saved to '{output_path}'.")
|
|
|
|
|
|
class TransformedSubset(Dataset):
    """Dataset view that applies a transform on top of another dataset/subset.

    The wrapped ``subset`` must yield ``(img, label)`` pairs when indexed.
    On access the ``img`` part is passed through ``transform`` (when one is
    set) and the ``label`` is handed back untouched. This lets different
    splits of one underlying dataset use different transform pipelines.
    """

    def __init__(self, subset, transform=None):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        # Same number of samples as the wrapped subset.
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        # No (truthy) transform configured: return the raw pair as-is.
        if not self.transform:
            return sample, target
        return self.transform(sample), target
|
|
|
|
|
|
def _build_balanced_sampler(train_labels):
    """Return a WeightedRandomSampler weighting samples by inverse class count."""
    label_tensor = torch.as_tensor(train_labels, dtype=torch.long)
    class_counts = torch.bincount(label_tensor)
    class_weights = 1. / class_counts.float()
    # One vectorised lookup instead of indexing the weight tensor per sample.
    sample_weights = class_weights[label_tensor]
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)


def _evaluate(model, loader, criterion, device):
    """Return (summed per-sample loss, number of correct predictions) over loader."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # criterion averages over the batch; scale back to a per-sample sum.
            total_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    return total_loss, correct


def train_model():
    """Train ``GarageDoorCNN`` on the images in ``data/labelled`` and save it.

    The image folder is split randomly 80/20 into training and validation
    sets. Training images get colour-jitter augmentation and are drawn
    through a class-balancing ``WeightedRandomSampler``; validation images
    are only cropped, resized and normalised. The model is optimised with
    Adam and cross-entropy loss for ``NUM_EPOCHS`` epochs, printing training
    loss plus validation loss/accuracy after each epoch. The final weights
    (``state_dict``) are written to ``garage_door_cnn.pth``.

    Raises:
        ValueError: if the dataset is too small to leave any training sample.
    """
    # --- Hyperparameters and Configuration ---
    DATA_DIR = 'data/labelled'
    MODEL_SAVE_PATH = 'garage_door_cnn.pth'
    NUM_EPOCHS = 10
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-5  # L2 regularization

    # --- Data Preparation ---
    # Define separate transforms for training (with augmentation) and validation (without)
    train_transforms = transforms.Compose([
        CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_transforms = transforms.Compose([
        CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
        transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load dataset without transforms, as they will be applied to subsets
    untransformed_dataset = datasets.ImageFolder(DATA_DIR)

    # Split into training and validation sets
    train_size = int(0.8 * len(untransformed_dataset))
    val_size = len(untransformed_dataset) - train_size
    if train_size == 0:
        # Fail early with a clear message instead of an obscure sampler error.
        raise ValueError(
            f"Not enough images in '{DATA_DIR}' to build a training split "
            f"(found {len(untransformed_dataset)})."
        )
    train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size])

    # Apply the respective transforms to the subsets using our wrapper
    train_dataset = TransformedSubset(train_subset, transform=train_transforms)
    val_dataset = TransformedSubset(val_subset, transform=val_transforms)

    # --- Handle Class Imbalance ---
    # Labels of the training samples, looked up through the subset indices
    train_labels = [untransformed_dataset.targets[i] for i in train_subset.indices]
    sampler = _build_balanced_sampler(train_labels)

    # DataLoader does not allow `sampler` together with shuffle=True; the
    # sampler already draws indices at random, so shuffle stays at its default.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- Model, Loss, Optimizer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = GarageDoorCNN(resize_dim=RESIZE_DIM).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # --- Training Loop ---
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        # _evaluate() switches to eval mode, so re-enable train mode each epoch.
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Training Loss: {epoch_loss:.4f}")

        # --- Validation Loop ---
        val_loss, corrects = _evaluate(model, val_loader, criterion, device)
        val_epoch_loss = val_loss / len(val_dataset)
        val_epoch_acc = corrects / len(val_dataset)
        print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.4f}")

    # --- Save the trained model ---
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: either dump a debug crop (--check-crop) or
    # run training once both labelled class directories are present.
    parser = argparse.ArgumentParser(description="Train a CNN for garage door detection or check the image crop.")
    parser.add_argument('--check-crop', action='store_true', help='Save a sample cropped image and exit.')
    args = parser.parse_args()

    if args.check_crop:
        check_crop()
    elif not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'):
        # Report the same paths that are actually checked above.
        print("Error: Data directories 'data/labelled/open' and 'data/labelled/closed' not found.")
        print("Please create them and place your image snapshots inside.")
    else:
        train_model()
|