fix: handle class imbalance with WeightedRandomSampler
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
21
train.py
21
train.py
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
||||
from torchvision import datasets, transforms
|
||||
from PIL import Image, ImageDraw
|
||||
import os
|
||||
@@ -38,7 +38,24 @@ def train_model():
|
||||
val_size = len(full_dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
# --- Handle Class Imbalance ---
|
||||
# Get labels for training set
|
||||
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
|
||||
|
||||
# Get class counts
|
||||
class_counts = torch.bincount(torch.tensor(train_labels))
|
||||
|
||||
# Compute weight for each class (inverse of count)
|
||||
class_weights = 1. / class_counts.float()
|
||||
|
||||
# Assign a weight to each sample in the training set
|
||||
sample_weights = torch.tensor([class_weights[label] for label in train_labels])
|
||||
|
||||
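# Create a WeightedRandomSampler that draws len(sample_weights) samples per epoch with replacement, so minority-class samples are repeated and classes are seen at roughly equal rates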
|
||||
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
# The sampler already randomizes the sample order, so the DataLoader's shuffle is left at its default (False); passing shuffle=True together with a sampler raises ValueError
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
|
||||
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
# --- Model, Loss, Optimizer ---
|
||||
|
Reference in New Issue
Block a user