feat: add ColorJitter augmentation to training data
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
52
train.py
52
train.py
@@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Dataset
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import os
|
import os
|
||||||
@@ -40,6 +40,26 @@ def check_crop():
|
|||||||
print(f"Debug image saved to '{output_path}'.")
|
print(f"Debug image saved to '{output_path}'.")
|
||||||
|
|
||||||
|
|
||||||
|
class TransformedSubset(Dataset):
    """Wrap a ``Subset`` so a transform can be applied per split.

    ``random_split`` produces ``Subset`` objects that inherit the parent
    dataset's transform, so a transform cannot be attached to a subset
    directly; this wrapper lets the training and validation splits carry
    different transform pipelines.
    """

    def __init__(self, subset, transform=None):
        # Underlying Subset (any indexable yielding (img, label) pairs).
        self.subset = subset
        # Optional callable applied to the image on each access.
        self.transform = transform

    def __getitem__(self, index):
        # The wrapped subset returns (img, label) from the original dataset.
        sample, target = self.subset[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.subset)
def train_model():
|
def train_model():
|
||||||
# --- Hyperparameters and Configuration ---
|
# --- Hyperparameters and Configuration ---
|
||||||
DATA_DIR = 'data/labelled'
|
DATA_DIR = 'data/labelled'
|
||||||
@@ -49,25 +69,37 @@ def train_model():
|
|||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
|
|
||||||
# --- Data Preparation ---
|
# --- Data Preparation ---
|
||||||
# Define transforms
|
# Define separate transforms for training (with augmentation) and validation (without)
|
||||||
data_transforms = transforms.Compose([
|
train_transforms = transforms.Compose([
|
||||||
|
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
||||||
|
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
|
||||||
|
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
val_transforms = transforms.Compose([
|
||||||
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT),
|
||||||
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
transforms.Resize((RESIZE_DIM, RESIZE_DIM)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
# Load dataset with ImageFolder
|
# Load dataset without transforms, as they will be applied to subsets
|
||||||
full_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms)
|
untransformed_dataset = datasets.ImageFolder(DATA_DIR)
|
||||||
|
|
||||||
# Split into training and validation sets
|
# Split into training and validation sets
|
||||||
train_size = int(0.8 * len(full_dataset))
|
train_size = int(0.8 * len(untransformed_dataset))
|
||||||
val_size = len(full_dataset) - train_size
|
val_size = len(untransformed_dataset) - train_size
|
||||||
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size])
|
||||||
|
|
||||||
|
# Apply the respective transforms to the subsets using our wrapper
|
||||||
|
train_dataset = TransformedSubset(train_subset, transform=train_transforms)
|
||||||
|
val_dataset = TransformedSubset(val_subset, transform=val_transforms)
|
||||||
|
|
||||||
# --- Handle Class Imbalance ---
|
# --- Handle Class Imbalance ---
|
||||||
# Get labels for training set
|
# Get labels for training set from the subset indices
|
||||||
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
|
train_labels = [untransformed_dataset.targets[i] for i in train_subset.indices]
|
||||||
|
|
||||||
# Get class counts
|
# Get class counts
|
||||||
class_counts = torch.bincount(torch.tensor(train_labels))
|
class_counts = torch.bincount(torch.tensor(train_labels))
|
||||||
|
Reference in New Issue
Block a user