Sampling strategies fall into the following cases:
case0: over-sampling & under-sampling combined: down-sample the majority classes and up-sample the minority classes
case1: over-sampling: re-sample the minority classes so that every class ends up with as many samples as the largest class
case2: under-sampling: down-sample the majority classes so that every class ends up with as many samples as the smallest class
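As a small sketch (with hypothetical class counts, not taken from a real dataset), the three cases differ only in how many samples the sampler draws per epoch:

```python
import numpy as np

# Hypothetical per-class counts for a binary dataset
class_sample_count = np.array([5865, 10238])
class_nums = len(class_sample_count)

# case1 (over-sampling): every class drawn up to the size of the largest class
num_samples_over = int(class_sample_count.max() * class_nums)   # 2 * 10238 = 20476

# case2 (under-sampling): every class drawn down to the size of the smallest class
num_samples_under = int(class_sample_count.min() * class_nums)  # 2 * 5865 = 11730

# case0 (mixed): keep the original dataset size
num_samples_mixed = int(class_sample_count.sum())               # 16103
```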
# Code to compute the per-sample weight (sampling probability)
import numpy as np

# labels is the class label of each image, e.g. [0, 0, 1, 1, 0, 1, 1, 0, 0, 1]
classes, class_sample_count = np.unique(labels, return_counts=True)
weights = np.zeros(len(labels))
for c in classes:
    freq_c = np.sum(np.array(labels) == c)  # number of samples in class c
    weights[np.array(labels) == c] = len(labels) / freq_c  # inverse-frequency weight for class c
    # equivalently: weights[np.array(labels) == c] = 1. / freq_c
weights = list(weights)
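A quick sanity check on this weighting scheme: with len(labels)/freq_c weights, every class carries the same total weight, so the sampler draws each class with equal probability. A minimal sketch using a hypothetical 7:3 imbalanced label list:

```python
import numpy as np

labels = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])  # hypothetical 7:3 imbalance
classes, counts = np.unique(labels, return_counts=True)

weights = np.zeros(len(labels))
for c, n in zip(classes, counts):
    weights[labels == c] = len(labels) / n  # 10/7 for class 0, 10/3 for class 1

# Each class's total weight is identical, so each class is equally likely per draw
per_class_weight = [weights[labels == c].sum() for c in classes]
```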
Usage:
# With replacement=True the sampler draws minority classes more often and majority
# classes less often, so each epoch comes out roughly 1:1
train_nums = len(train_dataset)  # default: keep the original dataset size
# Over-sampling (case1): draw every class up to the size of the largest class
class_sample_count = np.array(list(train_dataset.class_sample_count_dict.values()))  # samples per class
class_sample_count_max = class_sample_count.max()  # size of the largest class
class_nums = len(train_dataset.class_sample_count_dict)  # number of classes
train_nums = int(class_sample_count_max * class_nums)  # total samples drawn per epoch
sampler_weights = train_dataset.weights
train_sampler = torch.utils.data.WeightedRandomSampler(weights=sampler_weights,
                                                       num_samples=train_nums,
                                                       replacement=True)
# Pass the sampler to the DataLoader; shuffle must be False whenever a sampler is
# given, which the expression (train_sampler is None) guarantees
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)
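The sampler itself only produces dataset indices; the DataLoader then fetches those samples. A self-contained sketch (hypothetical inverse-frequency weights, no real dataset needed) showing that it yields exactly num_samples indices into the dataset range:

```python
import numpy as np
import torch

# Hypothetical class counts; each sample gets the inverse frequency of its class
class_sample_count = np.array([5865, 10238])
sampler_weights = np.repeat(1.0 / class_sample_count, class_sample_count)

num_samples = int(class_sample_count.max() * len(class_sample_count))  # 20476
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sampler_weights, num_samples=num_samples, replacement=True)

indices = list(sampler)  # the indices a DataLoader would use to fetch samples
```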
# Verify that the classes are now balanced
label_nums = [0, 0]  # binary classification
for batch_idx, (data, target) in enumerate(train_loader):
    for target_i in target:
        label_id = target_i.item()
        label_nums[label_id] += 1
print(label_nums)
This gives label_nums = [8137, 7966], roughly 1:1; without the sampler the counts were [5865, 10238], about 1:1.7.
https://www.cnblogs.com/huadongw/p/6159408.html