一文读懂Pytorh Sampler

> 本文从DataSet、DataLoader和Sampler的关系出发,介绍Pytorch实现的五种采样,并应用到DataLoader中。

🎏目录

    🎈1 DataSet、DataLoader和Sampler的关系
    🎈2 Sampler
      🎄2.1 SequentialSampler(顺序采样)
      🎄2.2 RandomSampler(随即采样)
      🎄2.3 BatchSampler(批采样)
      🎄2.4 SubsetRandomSampler(子集随机采样)
      🎄2.5 WeightedRandomSampler(加权随机采样)

✨1 DataSet、DataLoader和Sampler的关系

我们知道DataSet建立数据集,本质是读取一张张图像。而DataLoader是将DataSet中的图像一个个取出来,打包成一个个batch。
但是这里存在一个问题,DataLoader从Dataet中是如何取一张张图像的,该问题对我们训练也有影响
假设我们数据集是按照类别放在一起的,那么DataSet的读取的图像也是按照类别放在一起的。此时,如果DataLoader顺序读取打包,则可能出现每个batch中都是同一个类别的图像。这就会影响我们模型的训练效果。

因此需要Sampler决定打包时的读取图像的顺序。这就是三者之间的关系。

✨2 Sampler

Pytorch中实现了五种Sampler:

  1. SequentialSampler(顺序采样)
  2. RandomSampler(随机采样)
  3. WeightedSampler(加权随机采样)
  4. SubsetRandomSampler(子集随机采样)
  5. BatchSampler(批采样)

(其中1,2,5可应用到DataLoader中,第三节详细展开)

🎃 2.1 SequentialSampler(顺序采样)

用于获取数据索引

torch.utils.data.SequentialSampler(
	data_source,
)

参数:

  1. data_source:可迭代数据,一般为数据集

返回:
顺序返回数据集索引

示例:
在这里插入图片描述

🎉 2.2 RandomSampler(随即采样)

用于获取打乱的数据索引

torch.utils.data.SequentialSampler(
	data_source,
	num_samples,
	replacement,
)

参数:

  1. data_source:同上
  2. num_samples:指定采样的数量,默认是所有
  3. replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。

返回:
乱序返回数据集索引
在这里插入图片描述

🎄2.3 BatchSampler(批采样)

BatchSampler将前面的Sampler采样得到的单个的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。(训练时使用的是批量数据)

torch.utils.data.BatchSampler(
	sampler,
	batch_size, 
	drop_last,
)

参数:

  1. sampler:上述两种采样器,即SequentialSampler或RandomSampler
  2. batch_size:batch的大小
  3. drop_last:True或False。drop_last为True时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。

返回:
分组完成的数据索引shape=(num_data/batch_size, batch_size)

比较抽象,下面举一个例子:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler

a = [1,5,78,9,68]
b = BatchSampler(a, 2, False)
print(list(b))

在这里插入图片描述
可以看到已经分成三组,每组大小都是设置的batch_size=2。而drop_last=False,并未去掉于batch_size的分组。

🎄2.4 SubsetRandomSampler(子集随机采样)

torch.utils.data.SubsetRandomSampler(
	indices
)

参数:

  1. indices:数据集索引

返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身

该方法更多应用于切分数据集,比如:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler

a = [1,5,78,9,68]
b1 = torch.utils.data.SubsetRandomSampler(a[:3])
b2 = torch.utils.data.SubsetRandomSampler(a[3:])
for x in b1:
    print("train:", x)
for x in b2:
    print("val:", x)

在这里插入图片描述

🎃 2.5 WeightedRandomSampler(加权随机采样)

torch.utils.data.WeightedRandomSampler(
	 weights, 
	 num_samples, 
	 replacement=True)

参数:

  1. weights:采样到该索引的权重
  2. num_samples:指定采样的数量,默认是所有
  3. replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。

返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身

示例代码:

import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler

a = [1,5,78,9,68]
weights = [0, 3, 1.1, 1.1, 1.1, 1.1, 1.1]
b = WeightedRandomSampler(weights, 7, replacement=True)
for i in b:
    print(i)

在这里插入图片描述
代码中,replacement设置为True,允许重复采样后,由于位置1的权重为3比较大,因此被采样次数较多。

✨3 应用

了解上面五种Sampler后,如何在我们的项目中使用是重点:

  1. 采用
  2. DataLoader应用

🎃 3.1 采样

首先,创建顺序采样或随机采样,比如:

sampler = torch.utils.data.RandomSampler(train_dataset)  # train_dataset,自定义数据集(重载的DataSet)

其次,在上面的基础上创建批采样:

batch_sampler_train = torch.utils.data.BatchSampler(sampler, 16, drop_last=True)

结果类似:
在这里插入图片描述

🎉 3.2 DataLoader应用

其中,指定顺序采样或随机采样用到DatLoader的参数sampler。而指定批采样的参数是batch_sampler
由于参数之间可能冲突,使用时分为以下几种情况:

  1. sampler和batch_sampler都为None:batch_sampler使用Pytorch实现的批采样,而sampler分为两种情况
    ====================================================================
    a). shuffle=True,则sampler使用随机采样
    b). shuffle=False,则sampler使用顺序采样====================================================================
  2. 自定义了batch_sampler,那么batch_sizeshufflesamplerdrop_last必须都是默认值
  3. 自定义了sampler,此时batch_sampler不能再指定,且shuffle必须为False。
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

白三点

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值