feat: add ColorJitter augmentation to training data

Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
2025-07-31 17:09:48 -06:00
parent 89b407c564
commit 77c103b9fe

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Dataset
from torchvision import datasets, transforms
from PIL import Image, ImageDraw
import os
@@ -40,6 +40,26 @@ def check_crop():
print(f"Debug image saved to '{output_path}'.")
class TransformedSubset(Dataset):
    """
    Wrap a dataset subset so a transform can be applied per split.

    torch's ``Subset`` offers no hook for attaching a transform after
    ``random_split``; this wrapper intercepts ``__getitem__`` and runs the
    (optional) transform on each image lazily, leaving labels untouched.
    """

    def __init__(self, subset, transform=None):
        # Hold references only; no copying of the underlying data.
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        # The wrapped subset yields (image, label) pairs from the base dataset.
        sample, target = self.subset[index]
        # Apply the transform only when one was supplied (truthy check,
        # matching the original behavior).
        if self.transform:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        # Delegate length straight to the wrapped subset.
        return len(self.subset)
def train_model():
# --- Hyperparameters and Configuration ---
DATA_DIR = 'data/labelled'
@@ -49,25 +69,37 @@ def train_model():
LEARNING_RATE = 0.001
# --- Data Preparation ---
# Define transforms
data_transforms = transforms.Compose([
# Define separate transforms for training (with augmentation) and validation (without)
train_transforms = transforms.Compose([
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load dataset with ImageFolder
full_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms)
# Load dataset without transforms, as they will be applied to subsets
untransformed_dataset = datasets.ImageFolder(DATA_DIR)
# Split into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_size = int(0.8 * len(untransformed_dataset))
val_size = len(untransformed_dataset) - train_size
train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size])
# Apply the respective transforms to the subsets using our wrapper
train_dataset = TransformedSubset(train_subset, transform=train_transforms)
val_dataset = TransformedSubset(val_subset, transform=val_transforms)
# --- Handle Class Imbalance ---
# Get labels for training set
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
# Get labels for training set from the subset indices
train_labels = [untransformed_dataset.targets[i] for i in train_subset.indices]
# Get class counts
class_counts = torch.bincount(torch.tensor(train_labels))