torch.utils.data.WeightedRandomSampler采样

采样策略可分为以下情况:

case0:Over sampling &Under sampling ,即对类别多的进行降采样,对类别少的进行重采样
case1: Over sampling 对类别少的进行重采样,采样后的每类样本数与最多的那一类一致
case2:Under sampling 对类别多的进行降采样,采样后的每类样本数与最少的那一类一值 
# 计算权重概率代码
# lables是每张图像的类别list [0,0,1,1,0,1,1,0,0,1]
 classes, class_sample_count = np.unique(labels, return_counts=True)
        # classes = np.unique(labels)
        weights = np.zeros(len(labels))
        for c in classes:
            freq_c = np.sum(np.array(labels) == c)  # 该类的数量
            weights[np.array(labels) == c] = len(labels) / freq_c # 该类的概率
            # weights[np.array(labels) == c] = 1. / freq_c 
        weights = list(weights)

使用代码

#  有点类似欠采样,接近1:1, 少的多采样,多的少采样
train_nums = len(train_dataset)  # 所有类别的数量

# 过采样, 每类数量过采样到最大类别的数量上
class_sample_count = np.array(list(train_dataset.class_sample_count_dict.values()))  #
class_sample_count_max = class_sample_count.max()  # 类别中数量最多的数量
class_nums = len(list(train_dataset.class_sample_count_dict.keys()))  # 类别数
train_nums = int(class_sample_count_max * class_nums)  # 总的数量

sampler_weights = train_dataset.weights
train_sampler = torch.utils.data.WeightedRandomSampler(weights=sampler_weights, num_samples=train_nums,
                                                               replacement=True)
 # 使用
 train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
  # 验证是否进行类别平衡
    label_nums = [0, 0]  # 二分类
    for batch_idx, (data, target) in enumerate(train_loader):
        for target_i in target:
            lable_id = target_i.item()
            label_nums[lable_id] += 1
            print("dddd", lable_id)  
得到label_nums : [8137, 7966],近似为1:1 ,之前是[ 5865 10238],接近 1:1.7
https://www.cnblogs.com/huadongw/p/6159408.html
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>