由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,pytorch提供了Sampler基类【1】与多个子类实现不同方式的数据采样。子类包含:
- Sequential Sampler(顺序采样)
- Random Sampler(随机采样)
- Subset Random Sampler(子集随机采样)
- Weighted Random Sampler(加权随机采样)等等。
1、基类Sampler
class Sampler(object):
r"""Base class for all Samplers.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。
2、顺序采样Sequential Sampler
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;__iter__()方法负责返回一个可迭代对象,这个可迭代对象是由range产生的顺序数值序列,也就是说迭代是按照顺序进行的。前述几种方法都只需要self.data_source实现了__len__()方法,因为这几种方法都仅仅使用了len(self.data_source)函数。所以下面采用同样实现了__len__()的list类型来代替Dataset类型做测试:
# 定义数据和对应的采样器
data = list([17, 22, 3, 41, 8])
seq_sampler &