数据不平衡, pytorch——WeightedRandomSampler

官网

参考官网:classtorch.utils.data.WeightedRandomSampler

在这里插入图片描述

参数解释如下:Parameters

    weights (sequence) – a sequence of weights, not necessary summing up to one

    num_samples (int) – number of samples to draw

    replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

    generator (Generator) – Generator used in sampling.

1.weights参数对应的是“样本”的权重而不是“类别的权重”。 也就是说:有一千个样本,weight的数值就有一千个,因此有 len(weight)= 样本数
2.num_sampler用于控制生成的个数;
3.replacement参数依旧用于控制采样是否是有放回的;(如果重复取,就true,否则就false,默认是true)
4.generator (Generator) – Generator used in sampling. 采样时用的生成器(不重要)

举例理解weight

说到这里还是有点迷惑,下面举例说明:
本文不采用官网的举例,而是借鉴一下这篇 博客中的举例说明问题
如下图,weight是一些tensor,代表每个位置的样本对应的权重,WeightedRandomSampler(weights, 6, True) ,表示 按照weight给出的权重,生成六个索引,而且是重复取样。
在这里插入图片描述
从输出可以看出,位置 [1] = 10 由于权重较大,被采样的次数较多,位置[0]由于权重为0所以没有被采样到,其余位置权重低所以都仅仅被采样一次。

到这里应该明白了WeightedRandomSampler(, , )各参数的含义,
那么如何获得某个数据集的权重呢?

weight = [ ] 里面每一项代表该样本种类占总样本的倒数。

例如: 数据集 animal = [ cat, cat, dog, dog, dog]
cat有两个,dog有三个,
先计算每种动物的占比,cat_count = 2/5 = 0.4 dog_count = 3/5 = 0.6
再计算count的倒数,也就是占比的倒数,这个数值就是weight
cat_weight = 1/count = 1/0.4 = 2.5 dog_weight = 1/count = 1/0.6 = 1.67
那么weight 列表就可以写作:weight = [2.5, 2.5, 1.67, 1.67, 1.67]
至此,weight的来源陈述结束

在自己的训练中使用weight

那么weight要如何在训练中使用呢?

# Weights for sampler into network, fixes class imbalance

weights_train = np.array(train_csv['0'])
weights_val = np.array(val_csv['0'])

weight_normal_train = len(weights_train) / (float) (np.count_nonzero(weights_train == 0))
weight_af_train = len(weights_train) / (float) (np.count_nonzero(weights_train == 1))

weight_normal_val = len(weights_val) / (float) (np.count_nonzero(weights_val == 0))
weight_af_val = len(weights_val) / (float) (np.count_nonzero(weights_val == 1))

weights_train[weights_train == 0] = weight_normal_train
weights_train[weights_train == 1] = weight_af_train

weights_val[weights_val == 0] = weight_normal_val
weights_val[weights_val == 1] = weight_af_val

weights_train = torch.DoubleTensor(weights_train.astype('float32'))
weights_val = torch.DoubleTensor(weights_val.astype('float32'))

train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights_train, len(weights_train))
val_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights_val, len(weights_val))

print(weight_normal_train, weight_af_train)

结果是:

1.1428436018957346 8.000663570006635

制作数据集接口

# Generate training and validation datasets
train_set = AFDataset('training_data1.csv', 'CTdata1/')
train_generator = DataLoader(train_set, **params, sampler = train_sampler)

val_set = AFDataset('validation_data1.csv', 'CTdata1/')
val_generator = DataLoader(val_set, **params, sampler = val_sampler)

test_set = AFDataset('test_data1.csv', 'CTdata1/')
test_generator = DataLoader(test_set, **params)

在训练的时候,在每个epoch里面写作

 for i, (local_batch, local_labels) in enumerate(train_generator, 0):

写在如下图的位置
在这里插入图片描述

最后有个疑问

在数据量本身就不多的情况下,如果使用这种采样方式,岂不是会丢失原始数据?

除此之外,还有另外的方法

在计算loss的时候考虑 weight,这样就不会丢失原始数据,(基本思路是选择
focalloss损失函数)具体实现待完善下篇~

另外,最重要的一点:
有帮助的话 点个赞 ~~ ⁂((✪⥎✪))⁂

  • 16
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
PyTorch中进行数据清洗的过程通常包括以下几个步骤: 1. 加载数据集:首先,你需要加载你的数据集。这可以通过使用PyTorch提供的数据加载器类(如`torchvision.datasets`)来完成,或者自定义一个数据加载器类。 2. 数据预处理:在加载数据集之后,你可能需要对数据进行一些预处理操作,如裁剪、缩放、标准化等。PyTorch提供了许多预处理函数(如`torchvision.transforms`)来帮助你完成这些操作。 3. 数据转换:一旦数据预处理完成,你可能需要将数据转换为PyTorch所需的张量格式。你可以使用`torch.from_numpy()`将NumPy数组转换为张量,或使用`torch.tensor()`创建一个新的张量。 4. 数据清洗:数据清洗的具体操作取决于你的数据集和任务需求。常见的数据清洗操作包括去除缺失值、处理异常值、删除重复样本等。你可以使用NumPy或Pandas等库来执行这些操作。 5. 数据划分:在数据清洗之后,你可能需要将数据集划分为训练集、验证集和测试集。你可以使用PyTorch提供的数据集拆分函数(如`torch.utils.data.random_split()`)来实现。 6. 数据加载器:最后,你需要创建数据加载器来批量加载和迭代数据PyTorch提供了`torch.utils.data.DataLoader`类,可以帮助你方便地创建数据加载器,并支持批量加载、数据随机化等功能。 以上是一个基本的数据清洗流程,具体的实现细节会根据你的数据集和任务而有所不同。希望对你有所帮助!如果你有任何其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值