fix: handle class imbalance with WeightedRandomSampler
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
21
train.py
21
train.py
@@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import os
|
import os
|
||||||
@@ -38,7 +38,24 @@ def train_model():
|
|||||||
val_size = len(full_dataset) - train_size
|
val_size = len(full_dataset) - train_size
|
||||||
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
# --- Handle Class Imbalance ---
|
||||||
|
# Get labels for training set
|
||||||
|
train_labels = [full_dataset.targets[i] for i in train_dataset.indices]
|
||||||
|
|
||||||
|
# Get class counts
|
||||||
|
class_counts = torch.bincount(torch.tensor(train_labels))
|
||||||
|
|
||||||
|
# Compute weight for each class (inverse of count)
|
||||||
|
class_weights = 1. / class_counts.float()
|
||||||
|
|
||||||
|
# Assign a weight to each sample in the training set
|
||||||
|
sample_weights = torch.tensor([class_weights[label] for label in train_labels])
|
||||||
|
|
||||||
|
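# Create a WeightedRandomSampler that draws one epoch's worth of samples with replacement, so minority-class samples are repeated and classes appear roughly balanced during training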
|
||||||
|
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||||
|
|
||||||
|
# The sampler already draws indices in random order; DataLoader's `sampler` and `shuffle=True` are mutually exclusive, so `shuffle` is left at its default (False)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
|
||||||
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||||
|
|
||||||
# --- Model, Loss, Optimizer ---
|
# --- Model, Loss, Optimizer ---
|
||||||
|
Reference in New Issue
Block a user