pytorch之sampeler dataset collate_fn dataloader 之间的关系

学习pytorch自然不可避免的会遇到需要自己定制数据集的情况,许多人在这一步时都会困惑于sampeler,dataset,collate_fn,dataloader这四者之间的关系,当然也包括我自己,写这个博文一是记录下来避免自己搞忘,二来希望对与我有同样困惑的人有些许帮助.

许多人为了搞清楚它们之间的关系,往往都是看源码,但是这样一来会花费大量时间,而且如果没有到达需要深入理解源码的层次,看了之后很快又会忘记.其实对于初学者,我们只需要记住结论即可,暂且不必深究其源码.所以下面我基本上只讲结论,同时以代码证明我的结论.具体递进方向为从顶向下.

## 1.dataloader接口

class DataLoader(object):
    def __init__(self, dataset,  # 数据集
                       batch_size=1,  # 每一次调用__iter__时,返回的数据长度
                       shuffle=False,  # 是否打乱顺序
                       sampler=None,  # 用于取数据index
                       batch_sampler=None,  # 用于取数据index(每次取一个batch_size长度数据)
                       num_workers=0,   # 进程数,如果多卡或者多机多卡训练,这里需要指定
                       collate_fn=None,   # 定义如何加载数据
                       pin_memory=False,   # 用于加速
                       drop_last=False,   # 当数据数量不能被batch_size整除时是都丢弃余数
                       timeout=0,   # 加载数据超时时间
                       worker_init_fn=None,   # 几乎不用动,默认值即可
                       multiprocessing_context=None):   # 几乎不用动,默认值即可

dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问.在访问时,可以通过"for inputs, labels in dataloaders:"的方式获取一个又一个batch_size的数据(取决于collate_fn对于数据的封装和dataset的返回).

但是问题在于,每次迭代中,这个dataloader是怎么获取到数据的呢?我们假设数据集是图片,在每次迭代中,dataloader怎么知道该读哪几张?又是怎么从硬盘读取到内存中?又是在内存中怎么组成一个batch的呢?

基于以上几个问题,引出下面的内容.

## 2. sampler与batch_sampler

这两者就是为了解决dataloader怎么知道读哪几张图片这个问题的.

### 2.1 sampler

在pytorch中,sampler有许多种,但都是以torch.utils.data.Sampler(data_source: Optional[collections.abc.Sized])为父类.其中参数data_source为dataset.具体参数就不详解了,否则太过冗长反而抓不住重点,各位条理清晰之后可以看一下官网说明:https://pytorch.org/docs/master/data.html#torch.utils.data.Sampler

1. torch.utils.data.SequentialSampler(data_source): 顺序采样器,采样顺序固定

dataset = MyDataSets(100)
r = list((torch.utils.data.SequentialSampler(dataset))
print(r)

输出为:

[0, 1, 2, 3, 4, 5...97, 98, 99]

2. torch.utils.data.RandomSampler(data_source: collections.abc.Sizedreplacement: bool = Falsenum_samples: Optional[int] = Nonegenerator=None): 无放回地随机采样样本元素

3. torch.utils.data.SubsetRandomSampler(indices: Sequence[int], generator=None): 无放回地按照给定的索引列表采样样本元素

4. torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本

5. torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集

这里重点说一下DistributedSampler,因为这个是在实际应用中,实现多机多卡训练的必选采样器.

在多机多卡的训练中,如何保证每个进程(即每一张显卡)读取到的数据没有重复?一种做法是选择一个比较大的batch_size,然后在每次采集到一个batch之后将数据拆分到所有进程,这样就可以保证每个进程没有重复数据,但是同时会带来一个很大的问题--在多机之间传输数据是非常耗时的,这样操作会大大降低训练速度.那么有没有一种采样器,使得在采样时就将整个数据集拆开,每个进程训练时就不需要管其他进程的数据,只需要从自己分到的数据集中采集一个batch的数据呢?DistributedSampler就是干这个事的.看以下代码印证:

dataset = MyDataSets(100)
r = list(torch.utils.data.DistributedSampler(dataset))
print(r)

当我在终端以" python -m torch.distributed.launch --nproc_per_node=2 distributedDataParallel.py "启动脚本时,输出为:

[44, 93, 71, 37, 53, 81, 80, 74, 76, 82, 26, 57, 16, 89, 97, 31, 35, 65, 98, 20, 78, 94, 4, 59, 54, 3, 2, 6, 49, 55, 79, 45, 32, 11, 22, 87, 25, 36, 9, 62, 18, 75, 13, 39, 64, 38, 14, 33, 86, 99]
[19, 90, 69, 95, 91, 42, 85, 56, 63, 40, 92, 10, 66, 41, 8, 24, 30, 7, 23, 29, 61, 15, 52, 5, 46, 28, 70, 60, 68, 72, 77, 1, 34, 0, 12, 50, 47, 96, 83, 84, 17, 67, 48, 21, 88, 27, 73, 58, 43, 51]

可以看到,数据集在不同进程的采样完全没有重复,那么各自进程的batch_sampler在打包数据的时候自然也不会重复.

从以上代码可以看出,sampler与dataset的关系在于sampler需要根据dataset生成index.

### 2.2 batch_sampler

当然,sampler只是生成一个整体的index,但是训练是以batch_size为迭代大小的,也就是说需要将这些index打包成一个一个的batch_size大小的包,每次取数据时就取每个包里面的index对应的数据.

而batch_sampler就是实现将index打包的类.在pytorch中实现的batch_sampler只有一个:torch.utils.data.BatchSampler(sampler, batch_size, drop_last).前两个参数不需要解释了,第三个参数为:当sampler采样到的数据长度不能被batch_size整除时,剩余的部分是否被丢弃.

下面看一下以DistributedSampler为采样器,使用BatchSampler之后的效果:

dataset = MyDataSets(100)
r = list(torch.utils.data.BatchSampler(torch.utils.data.DistributedSampler(dataset), batch_size=3, drop_last=False))
print(r)

当我在终端以" python -m torch.distributed.launch --nproc_per_node=2 distributedDataParallel.py "启动脚本时,输出为:

[[19, 90, 69], [95, 91, 42], [85, 56, 63], [40, 92, 10], [66, 41, 8], [24, 30, 7], [23, 29, 61], [15, 52, 5], [46, 28, 70], [60, 68, 72], [77, 1, 34], [0, 12, 50], [47, 96, 83], [84, 17, 67], [48, 21, 88], [27, 73, 58], [43, 51]]
[[44, 93, 71], [37, 53, 81], [80, 74, 76], [82, 26, 57], [16, 89, 97], [31, 35, 65], [98, 20, 78], [94, 4, 59], [54, 3, 2], [6, 49, 55], [79, 45, 32], [11, 22, 87], [25, 36, 9], [62, 18, 75], [13, 39, 64], [38, 14, 33], [86, 99]]

可以看到,整个数据集被DistributedSampler分为了不重叠的两部分,然后这两部分被各自的BatchSampler打包成batch_size大小的包.每次循环迭代时,dataloader回去读取对应index的数据组成一个batch.

当有自己的batch定制需求时BatchSampler完全可以自己定制,定制时需要实现__iter__与__len__方法.

## 3. dataset

dataset顾名思义就是数据集,也就是我们自己定义的东西,在能说明问题的情况下我尽量简单的定义一个用于线性拟合的数据集:

class MyDataSets(torch.utils.data.Dataset):
    def __init__(self, dataset_size):
        self.len = dataset_size
        self.x_train = torch.unsqueeze(torch.linspace(-1, 1, dataset_size), dim=1)
        self.x_train = self.x_train.view([dataset_size, 1])
        noise = torch.randn(self.x_train.size(), dtype=torch.float)
        self.y_train = 5 * self.x_train + 10 + noise

    def __getitem__(self, index):
        return self.x_train[index], self.y_train[index], index

    def __len__(self):
        return self.len

定义这个数据集有两点需要注意一下,第一点,数据集必须继承于torch.utils.data.Dataset,而且必须重写__getitem__与__len__方法,第二点,__getitem__返回的数据,可以在迭代dataloader时取出来,返回的内容并不固定,比如这里多此一举的多返回了一个index.这些大家应该都知道了,不过不知道有人注意到一个情况没有,就是在数据集__getitem__返回时,返回的是[data, label, index]的形式,那么取batch_size次之后组成的数据维度应该是[batch_size, 3],怎么可以用 "for inputs, labels,indexs in dataloaders" 的形式来取一个batch_size的数据呢?要这样迭代数据,被迭代的数据,应该是batch_size个inputs, labels, indexs各自一堆.collate_fn就实现了这个功能.

## 4. collate_fn

假设我们以最简单的形式定义和迭代dataloader,如下:

dataset = MyDataSets(100)
dataloader = DataLoader(dataset, batch_size=3)

在调用时会发生什么呢?说到这里,不可避免的还是得简单的过一下源码,此处省去了与这个调用无关的代码,默认collate_fn如下:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    ......
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

在迭代dataloader时,每个batch都会以原始数据的形式丢入这个函数,在这里也就是维度为(batch_size, 3)的列表,然后进入判断,首先会进入最下面一个elif,对原始数据打包,再次迭代调用default_collate,这次迭代调用时分别是batch_size个input, batch_size个label以及batch_size个index作为参数传进来,然后进入第一个判断条件分支,torch.utils.data.get_worker_info()是判断是否是多线程,暂且不管,接着下面torch.stack将batch_size个输入拼接成一个大小为batch_size的向量.

这里有inputs, labels,indexs三个值,也就是需要迭代调用三次default_collate,将数据封装成三个batch_size大小的向量,最后组成一个列表,在每次迭代时,即可得到对应数据.

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值