import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Dataset
from torchvision import datasets, transforms
from PIL import Image, ImageDraw
import os
import argparse

from model import (CropLowerRightTriangle, GarageDoorCNN,
                   TRIANGLE_CROP_WIDTH, TRIANGLE_CROP_HEIGHT, RESIZE_DIM)

# Root of the labelled dataset; ImageFolder expects one sub-directory per class.
DATA_DIR = 'data/labelled'

# Per-channel normalization statistics shared by the train and validation pipelines.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def check_crop():
    """Save the triangle crop of one sample image to disk for visual inspection.

    Picks the first image (in sorted filename order) from ``<DATA_DIR>/open``,
    applies ``CropLowerRightTriangle`` with the configured dimensions and writes
    the result to ``cropped_debug_output.png`` in the current directory.
    Prints an error and returns without writing anything if no image is found.
    """
    sample_image_dir = f"{DATA_DIR}/open"

    image_names = []
    if os.path.isdir(sample_image_dir):
        # Sort for a deterministic choice (os.listdir order is arbitrary) and skip
        # non-image files (e.g. .DS_Store) using ImageFolder's own extension list.
        image_names = sorted(
            name for name in os.listdir(sample_image_dir)
            if name.lower().endswith(datasets.folder.IMG_EXTENSIONS)
        )
    if not image_names:
        print(f"Error: Cannot find sample image in '{sample_image_dir}'.")
        print("Please ensure the directory exists and contains images.")
        return

    sample_image_path = os.path.join(sample_image_dir, image_names[0])
    print(f"Creating debug crop from image: {sample_image_path}")

    # Release the file handle once decoded, and convert to RGB the same way
    # ImageFolder's default loader does, so the crop matches the training input.
    with Image.open(sample_image_path) as raw_image:
        image = raw_image.convert('RGB')

    cropper = CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH,
                                     triangle_height=TRIANGLE_CROP_HEIGHT)
    cropped_image = cropper(image)

    output_path = "cropped_debug_output.png"
    cropped_image.save(output_path)
    print(f"Debug image saved to '{output_path}'.")


class TransformedSubset(Dataset):
    """Wrap an indexable dataset (typically a ``Subset``) and transform its images.

    ``random_split`` returns ``Subset`` objects that all share the parent
    dataset's transform. This wrapper lets each split get its own transform,
    for example augmentation for training only.

    Args:
        subset: Indexable dataset whose items are ``(img, label)`` pairs.
        transform: Optional callable applied to ``img``; ``label`` is passed through.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        # The wrapped dataset yields the raw (img, label) pair from the parent dataset.
        img, label = self.subset[index]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.subset)


def _build_transforms():
    """Return ``(train_transforms, val_transforms)``.

    Both pipelines crop the lower-right triangle, resize, convert to a tensor and
    normalize; only the training pipeline adds ColorJitter augmentation.
    """
    def pipeline(augment):
        steps = [CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH,
                                        triangle_height=TRIANGLE_CROP_HEIGHT)]
        if augment:
            steps.append(transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2))
        steps += [
            transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
            transforms.ToTensor(),
            transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        ]
        return transforms.Compose(steps)

    return pipeline(augment=True), pipeline(augment=False)


def _make_balanced_sampler(targets, num_classes):
    """Return a sampler that draws every present class with equal total weight.

    Args:
        targets: 1-D integer tensor holding the class index of each training sample.
        num_classes: Number of classes in the dataset (sizes the count vector).
    """
    class_counts = torch.bincount(targets, minlength=num_classes)
    # Inverse-frequency weights; clamp avoids 1/0 for a class absent from this split.
    class_weights = 1.0 / class_counts.clamp(min=1).float()
    sample_weights = class_weights[targets]
    # Sampling with replacement, one epoch = as many draws as there are samples.
    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights),
                                 replacement=True)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    total_samples = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # The criterion averages over the batch, so re-weight by batch size.
        total_loss += loss.item() * inputs.size(0)
        total_samples += inputs.size(0)
    return total_loss / total_samples


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over ``loader`` without gradients."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += inputs.size(0)
    return total_loss / total_samples, correct / total_samples


def train_model(data_dir=DATA_DIR, model_save_path='garage_door_cnn.pth', num_epochs=10,
                batch_size=32, learning_rate=0.001, weight_decay=1e-5,
                train_fraction=0.8, seed=None):
    """Train ``GarageDoorCNN`` on an ImageFolder dataset and save its state dict.

    The dataset is split randomly into training and validation subsets. Class
    imbalance in the training subset is handled with a WeightedRandomSampler.
    Training and validation metrics are printed after every epoch, and the
    final weights are written to ``model_save_path``.

    Args:
        data_dir: Root directory with one sub-directory per class.
        model_save_path: Destination file for ``model.state_dict()``.
        num_epochs: Number of passes over the (re-sampled) training set.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate.
        weight_decay: Adam L2 regularization coefficient.
        train_fraction: Fraction of images assigned to the training split.
        seed: Optional seed for a reproducible train/validation split; ``None``
            uses torch's global RNG, as before.

    Raises:
        ValueError: If the split would leave the training or validation set empty.
    """
    train_transforms, val_transforms = _build_transforms()

    # Load without transforms; each split gets its own via TransformedSubset.
    untransformed_dataset = datasets.ImageFolder(data_dir)
    num_images = len(untransformed_dataset)
    train_size = int(train_fraction * num_images)
    val_size = num_images - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Cannot split {num_images} image(s) into non-empty training and validation "
            f"sets (train={train_size}, val={val_size}); add more images to '{data_dir}'."
        )

    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size],
                                            **split_kwargs)

    train_dataset = TransformedSubset(train_subset, transform=train_transforms)
    val_dataset = TransformedSubset(val_subset, transform=val_transforms)

    # --- Handle class imbalance ---
    train_targets = torch.tensor([untransformed_dataset.targets[i] for i in train_subset.indices])
    sampler = _make_balanced_sampler(train_targets, num_classes=len(untransformed_dataset.classes))

    # The sampler provides the (random) ordering, so the loader must not also shuffle.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # --- Model, loss, optimizer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = GarageDoorCNN(resize_dim=RESIZE_DIM).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # --- Training loop ---
    print("Starting training...")
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")

        val_epoch_loss, val_epoch_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.4f}")

    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Train a CNN for garage door detection or check the image crop.")
    parser.add_argument('--check-crop', action='store_true',
                        help='Save a sample cropped image and exit.')
    args = parser.parse_args()

    if args.check_crop:
        check_crop()
    elif not (os.path.isdir(os.path.join(DATA_DIR, 'open'))
              and os.path.isdir(os.path.join(DATA_DIR, 'closed'))):
        # Report the directories that are actually checked, and fail with a non-zero status.
        print(f"Error: Data directories '{DATA_DIR}/open' and '{DATA_DIR}/closed' not found.")
        print("Please create them and place your image snapshots inside.")
        raise SystemExit(1)
    else:
        train_model()