🎏目录
🎈1 DataSet、DataLoader和Sampler的关系
🎈2 Sampler
🎄2.1 SequentialSampler(顺序采样)
🎄2.2 RandomSampler(随即采样)
🎄2.3 BatchSampler(批采样)
🎄2.4 SubsetRandomSampler(子集随机采样)
🎄2.5 WeightedRandomSampler(加权随机采样)
✨1 DataSet、DataLoader和Sampler的关系
我们知道DataSet建立数据集,本质是读取一张张图像。而DataLoader是将DataSet中的图像一个个取出来,打包成一个个batch。
但是这里存在一个问题,DataLoader从Dataet中是如何取一张张图像的,该问题对我们训练也有影响:
假设我们数据集是按照类别放在一起的,那么DataSet的读取的图像也是按照类别放在一起的。此时,如果DataLoader顺序读取打包,则可能出现每个batch中都是同一个类别的图像。这就会影响我们模型的训练效果。
因此需要Sampler决定打包时的读取图像的顺序。这就是三者之间的关系。
✨2 Sampler
Pytorch中实现了五种Sampler:
- SequentialSampler(顺序采样)
- RandomSampler(随机采样)
- WeightedSampler(加权随机采样)
- SubsetRandomSampler(子集随机采样)
- BatchSampler(批采样)
(其中1,2,5可应用到DataLoader中,第三节详细展开)
🎃 2.1 SequentialSampler(顺序采样)
用于获取数据索引
torch.utils.data.SequentialSampler(
data_source,
)
参数:
- data_source:可迭代数据,一般为数据集
返回:
顺序返回数据集索引
示例:
🎉 2.2 RandomSampler(随即采样)
用于获取打乱的数据索引
torch.utils.data.SequentialSampler(
data_source,
num_samples,
replacement,
)
参数:
- data_source:同上
- num_samples:指定采样的数量,默认是所有
- replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。
返回:
乱序返回数据集索引
🎄2.3 BatchSampler(批采样)
BatchSampler将前面的Sampler采样得到的单个的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。(训练时使用的是批量数据)
torch.utils.data.BatchSampler(
sampler,
batch_size,
drop_last,
)
参数:
- sampler:上述两种采样器,即SequentialSampler或RandomSampler
- batch_size:batch的大小
- drop_last:True或False。drop_last为True时,如果采样得到的数据个数小于batch_size则抛弃本个batch的数据。
返回:
分组完成的数据索引shape=(num_data/batch_size, batch_size)
比较抽象,下面举一个例子:
import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
a = [1,5,78,9,68]
b = BatchSampler(a, 2, False)
print(list(b))
可以看到已经分成三组,每组大小都是设置的batch_size=2。而drop_last=False,并未去掉于batch_size的分组。
🎄2.4 SubsetRandomSampler(子集随机采样)
torch.utils.data.SubsetRandomSampler(
indices
)
参数:
- indices:数据集索引
返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身
该方法更多应用于切分数据集,比如:
import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler
a = [1,5,78,9,68]
b1 = torch.utils.data.SubsetRandomSampler(a[:3])
b2 = torch.utils.data.SubsetRandomSampler(a[3:])
for x in b1:
print("train:", x)
for x in b2:
print("val:", x)
🎃 2.5 WeightedRandomSampler(加权随机采样)
torch.utils.data.WeightedRandomSampler(
weights,
num_samples,
replacement=True,
)
参数:
- weights:采样到该索引的权重
- num_samples:指定采样的数量,默认是所有
- replacement:若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。
返回:
与上面返回数据的索引不同,这里返回的是对应索引的数据本身
示例代码:
import torch.utils.data
from torch.utils.data import BatchSampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
a = [1,5,78,9,68]
weights = [0, 3, 1.1, 1.1, 1.1, 1.1, 1.1]
b = WeightedRandomSampler(weights, 7, replacement=True)
for i in b:
print(i)
代码中,replacement
设置为True,允许重复采样后,由于位置1的权重为3比较大,因此被采样次数较多。
✨3 应用
了解上面五种Sampler后,如何在我们的项目中使用是重点:
- 采用
- DataLoader应用
🎃 3.1 采样
首先,创建顺序采样或随机采样,比如:
sampler = torch.utils.data.RandomSampler(train_dataset) # train_dataset,自定义数据集(重载的DataSet)
其次,在上面的基础上创建批采样:
batch_sampler_train = torch.utils.data.BatchSampler(sampler, 16, drop_last=True)
结果类似:
🎉 3.2 DataLoader应用
其中,指定顺序采样或随机采样用到DatLoader的参数sampler
。而指定批采样的参数是batch_sampler
。
由于参数之间可能冲突,使用时分为以下几种情况:
- sampler和batch_sampler都为None:batch_sampler使用Pytorch实现的批采样,而sampler分为两种情况
====================================================================
a).shuffle=True
,则sampler使用随机采样
b).shuffle=False
,则sampler使用顺序采样==================================================================== - 自定义了
batch_sampler
,那么batch_size
,shuffle
,sampler
,drop_last
必须都是默认值 - 自定义了
sampler
,此时batch_sampler
不能再指定,且shuffle
必须为False。