From 078c8937703009127dfe97e67576000b79c8a9e6 Mon Sep 17 00:00:00 2001 From: Tanner Collin Date: Thu, 31 Jul 2025 16:50:19 -0600 Subject: [PATCH] fix: handle class imbalance with WeightedRandomSampler Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) --- train.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 053644f..a42caa7 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, WeightedRandomSampler from torchvision import datasets, transforms from PIL import Image, ImageDraw import os @@ -38,7 +38,24 @@ def train_model(): val_size = len(full_dataset) - train_size train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) - train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) + # --- Handle Class Imbalance --- + # Get labels for training set + train_labels = [full_dataset.targets[i] for i in train_dataset.indices] + + # Get class counts + class_counts = torch.bincount(torch.tensor(train_labels)) + + # Compute weight for each class (inverse of count) + class_weights = 1. / class_counts.float() + + # Assign a weight to each sample in the training set + sample_weights = torch.tensor([class_weights[label] for label in train_labels]) + + # Create a WeightedRandomSampler to balance the classes during training + sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) + + # The sampler draws samples in random order; DataLoader rejects shuffle=True together with a sampler, so shuffle is left at its default (False) + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler) val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) # --- Model, Loss, Optimizer ---