feat: add ColorJitter augmentation to training data
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
52
train.py
52
train.py
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
||||
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Dataset
|
||||
from torchvision import datasets, transforms
|
||||
from PIL import Image, ImageDraw
|
||||
import os
|
||||
@@ -40,6 +40,26 @@ def check_crop():
|
||||
print(f"Debug image saved to '{output_path}'.")
|
||||
|
||||
|
||||
class TransformedSubset(Dataset):
    """Wrap a dataset subset so a transform can be applied on access.

    ``random_split`` produces ``Subset`` objects that carry no transform of
    their own, so this wrapper applies one lazily each time an item is
    fetched, leaving the underlying data untouched.
    """

    def __init__(self, subset, transform=None):
        # The wrapped Subset (or any indexable dataset yielding (img, label)).
        self.subset = subset
        # Optional callable applied to the image on every access.
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img, label)`` at *index*, with the transform applied to img."""
        sample, target = self.subset[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        """Delegate length to the wrapped subset."""
        return len(self.subset)
|
||||
|
||||
|
||||
def train_model():
|
||||
# --- Hyperparameters and Configuration ---
|
||||
DATA_DIR = 'data/labelled'
|
||||
@@ -49,25 +69,37 @@ def train_model():
|
||||
LEARNING_RATE = 0.001
|
||||
|
||||
# --- Data Preparation ---
|
||||
# Define transforms
|
||||
data_transforms = transforms.Compose([
|
||||
# Define separate transforms for training (with augmentation) and validation (without)
|
||||
train_transforms = transforms.Compose([
|
||||
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
||||
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
|
||||
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
val_transforms = transforms.Compose([
|
||||
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
||||
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Load dataset with ImageFolder
|
||||
full_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms)
|
||||
# Load dataset without transforms, as they will be applied to subsets
|
||||
untransformed_dataset = datasets.ImageFolder(DATA_DIR)
|
||||
|
||||
# Split into training and validation sets
|
||||
train_size = int(0.8 * len(full_dataset))
|
||||
val_size = len(full_dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
||||
train_size = int(0.8 * len(untransformed_dataset))
|
||||
val_size = len(untransformed_dataset) - train_size
|
||||
train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size])
|
||||
|
||||
# Apply the respective transforms to the subsets using our wrapper
|
||||
train_dataset = TransformedSubset(train_subset, transform=train_transforms)
|
||||
val_dataset = TransformedSubset(val_subset, transform=val_transforms)
|
||||
|
||||
# --- Handle Class Imbalance ---
|
||||
# Get labels for training set
|
||||
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
|
||||
# Get labels for training set from the subset indices
|
||||
train_labels = [untransformed_dataset.targets[i] for i in train_subset.indices]
|
||||
|
||||
# Get class counts
|
||||
class_counts = torch.bincount(torch.tensor(train_labels))
|
||||
|
Reference in New Issue
Block a user