处理 imbalance（类别不平衡）数据
数据集两类，比例严重失衡：金毛 100 张 : 哈士奇 1 张
Methods for dealing with imbalanced datasets:
- Oversampling：过采样少数类，即对少数类做不同的数据增强后重复采样
- Class weighting：给样本少的类别更大的损失权重，例如
  loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1, 50, 12, ...]))
import torch
import torchvision.datasets as datasets
import os
from torch.utils.data import WeightedRandomSampler, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
# Methods for dealing with imbalanced datasets:
# 1. Oversampling
# 2. Class weighting
def get_loader(root_dir, batch_size):
    """Build a DataLoader that oversamples minority classes.

    Each sample is drawn (with replacement) with probability proportional
    to the inverse of its class frequency, so every class contributes
    roughly equally per epoch regardless of the folder's imbalance.

    Args:
        root_dir: ImageFolder-style root (one subdirectory per class).
        batch_size: Batch size for the returned DataLoader.

    Returns:
        DataLoader over the ImageFolder dataset, using a
        WeightedRandomSampler with replacement.
    """
    my_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )
    dataset = datasets.ImageFolder(root=root_dir, transform=my_transforms)

    # BUG FIX: the original counted files via os.walk, whose subdirectory
    # order is OS-dependent and need not match ImageFolder's alphabetically
    # sorted class indices — weights could be assigned to the wrong class.
    # Count per-class samples from dataset.targets, which uses exactly the
    # label indices the sampler will see.
    class_counts = [0] * len(dataset.classes)
    for label in dataset.targets:
        class_counts[label] += 1
    class_weights = [1.0 / count for count in class_counts]

    # PERF FIX: the original iterated the dataset itself, decoding and
    # transforming every image just to read its label. dataset.targets
    # yields the labels without opening any image files.
    sample_weights = [class_weights[label] for label in dataset.targets]

    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(sample_weights), replacement=True
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
def main():
    """Sanity-check the balanced sampler.

    Runs ten epochs over the oversampled loader and tallies how often
    each class appears; with inverse-frequency weighting the two totals
    should come out roughly equal despite the imbalanced folder.
    """
    loader = get_loader(root_dir="dataset", batch_size=8)

    retriever_total = 0
    elkhound_total = 0
    for _ in range(10):
        for _, labels in loader:
            retriever_total += torch.sum(labels == 0)
            elkhound_total += torch.sum(labels == 1)

    print(retriever_total)
    print(elkhound_total)


if __name__ == "__main__":
    main()