fix: handle class imbalance with WeightedRandomSampler

Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
2025-07-31 16:50:19 -06:00
parent 5f18d8bce2
commit 078c893770

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms
from PIL import Image, ImageDraw
import os
@@ -38,7 +38,24 @@ def train_model():
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# --- Handle Class Imbalance ---
# Get labels for training set
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
# Get class counts
class_counts = torch.bincount(torch.tensor(train_labels))
# Compute weight for each class (inverse of count)
class_weights = 1. / class_counts.float()
# Assign a weight to each sample in the training set
sample_weights = torch.tensor([class_weights[label] for label in train_labels])
# Create a WeightedRandomSampler to balance the classes during training
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
# DataLoader's `sampler` and `shuffle` options are mutually exclusive; the sampler already draws indices at random, so `shuffle` is left at its default here
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# --- Model, Loss, Optimizer ---