Influence-Balanced Loss 中的Resample策略

 改进的sampler策略

    elif args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(dset_train)
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, dataset, indices=None, num_samples=None):
                
        # if indices is not provided, 
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices
            
        # if num_samples is not provided, 
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples # 数据集样本个数
            
        # distribution of classes in the dataset 
        label_to_count = [0] * len(np.unique(dataset.targets))
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            label_to_count[label] += 1
            
        beta = 0.9999
        effective_num = 1.0 - np.power(beta, label_to_count)
        per_cls_weights = (1.0 - beta) / np.array(effective_num) #各类别的权重 per_cls_weights: [0.00248924 0.00202661 0.00689909 0.00975834]

        # weight for each sample
        weights = [per_cls_weights[self._get_label(dataset, idx)]
                   for idx in self.indices] # 各样本的权重

        self.weights = torch.DoubleTensor(weights)
        
    def _get_label(self, dataset, idx):
        return dataset.targets[idx]
                
    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())

    def __len__(self):
        return self.num_samples

Class Counts: [410, 506, 146, 103]
per_cls_weights: [0.00248924 0.00202661 0.00689909 0.00975834]

0.00248924*410+0.00202661*506+0.00689909*146+103*0.00975834=4.05842922

普通sampler

继承了sampler类,然后重新为数据集中的各样本分配权重。

如果使用的是普通的采样器(sampler),例如 PyTorch 中的 RandomSampler 或简单的顺序采样,每个样本通常被赋予等权重。这意味着在抽样过程中,每个样本被选中的概率是相等的。

在这种情况下,假设数据集中有 𝑁个样本,那么每个样本被选中的概率和权重都是 1/𝑁​。这种方式不考虑数据集中可能存在的类别不平衡问题,每个样本被选取的机会完全相同。

例如,如果你有一个包含 100 个样本的数据集,并使用普通的采样器进行随机抽样,则每个样本被选中的概率都是 1%。这种采样方式简单且常用,但在处理类别极度不平衡的数据集时可能不够有效,因为它可能导致模型对多数类过拟合,而忽视了少数类。

ImbalancedDatasetSampler的采样策略的公式和CBReweight的公式差不多

两者都试图通过为每个类别的样本分配不同的权重来解决类别不平衡问题,但应用的场景和具体实现有所不同:

  • ImbalancedDatasetSampler:影响的是数据采样过程,通过改变数据输入模型的方式来达成类别平衡。
  • CBReweight:直接作用于模型的损失函数,通过改变损失计算方式来强调少数类的重要性。

尽管两者策略相似,但具体实现和影响的环节(数据层面 vs. 模型训练层面)有所区别。

ImbalancedDatasetSampler最后会将整个数据集的每个样本的权重列表送入官方写好的sampler里(继承普通的sampler类),CBReweight会将每个类的权重列表送入官方写好的代码里(交叉熵损失)

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值