class TransformedSubset(Dataset):
    """Wrap a subset and apply a transform to each image on access.

    The wrapped object is indexed to obtain an ``(img, label)`` pair; the
    transform, when one is given, is applied to ``img`` only and the label
    is passed through unchanged. This lets different splits of one
    underlying dataset use different transform pipelines.

    Args:
        subset: Any object supporting ``subset[index]`` (yielding an
            ``(img, label)`` pair) and ``len(subset)``.
        transform: Optional callable applied to ``img``. ``None`` (the
            default) returns the image untouched.
    """

    def __init__(self, subset, transform=None):
        super().__init__()
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(img, label)`` pair at ``index``, with ``img`` transformed."""
        img, label = self.subset[index]
        # Compare against None explicitly: a callable may legitimately be
        # falsy (e.g. it defines __len__ and is empty), and a truthiness
        # test would silently skip it.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples in the wrapped subset."""
        return len(self.subset)
0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) - # Load dataset with ImageFolder - full_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms) + # Load dataset without transforms, as they will be applied to subsets + untransformed_dataset = datasets.ImageFolder(DATA_DIR) # Split into training and validation sets - train_size = int(0.8 * len(full_dataset)) - val_size = len(full_dataset) - train_size - train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + train_size = int(0.8 * len(untransformed_dataset)) + val_size = len(untransformed_dataset) - train_size + train_subset, val_subset = random_split(untransformed_dataset, [train_size, val_size]) + + # Apply the respective transforms to the subsets using our wrapper + train_dataset = TransformedSubset(train_subset, transform=train_transforms) + val_dataset = TransformedSubset(val_subset, transform=val_transforms) # --- Handle Class Imbalance --- - # Get labels for training set - train_labels = [full_dataset.targets[i] for i in train_dataset.indices] + # Get labels for training set from the subset indices + train_labels = [untransformed_dataset.targets[i] for i in train_subset.indices] # Get class counts class_counts = torch.bincount(torch.tensor(train_labels))