dataloader 源码_PyTorch之DataLoader杂谈

输入数据PipeLine

pytorch 的数据加载到模型的操作顺序是这样的:

①创建一个 Dataset 对象

②创建一个 DataLoader 对象

③循环这个 DataLoader 对象,将img, label加载到模型中进行训练

dataset = MyDataset()

dataloader = DataLoader(dataset)

num_epoches = 100

for epoch in range(num_epoches):

for img, label in dataloader:

....

所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在

DataLoader的官方说明是:

“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”

关于iterator和iterable的区别和概念若有兴趣请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

1.DataLoader源码剖析

在分析源码之前,先介绍一下DataLoader(object)的参数:

dataset(Dataset): 传入的数据集

batch_size(int, optional): 每个batch有多少个样本

shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

显然,根据上面参数的解释,DataLoader这个类就是进行数据的初始化的操作,好了,下面来看源码吧:

class DataLoader(object):

__initialized = False

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,

batch_sampler=None, num_workers=0, collate_fn=default_collate,

pin_memory=False, drop_last=False, timeout=0,

worker_init_fn=None):

self.dataset = dataset

self.batch_size = batch_size

self.num_workers = num_workers

self.collate_fn = collate_fn

self.pin_memory = pin_memory

self.drop_last = drop_last

self.timeout = timeout

self.worker_init_fn = worker_init_fn

if timeout < 0:

raise ValueError('timeout option should be non-negative')

if batch_sampler is not None:

if batch_size > 1 or shuffle or sampler is not None or drop_last:

raise ValueError('batch_sampler option is mutually exclusive '

'with batch_size, shuffle, sampler, and '

'drop_last')

self.batch_size = None

self.drop_last = None

if sampler is not None and shuffle:

raise ValueError('sampler option is mutually exclusive with '

'shuffle')

if self.num_workers < 0:

raise ValueError('num_workers option cannot be negative; '

'use num_workers=0 to disable multiprocessing.')

if batch_sampler is None:

if sampler is None:

if shuffle:

sampler = RandomSampler(dataset)

else:

sampler = SequentialSampler(dataset)

batch_sampler = BatchSampler(sampler, batch_size, drop_last)

self.sampler = sampler

self.batch_sampler = batch_sampler

self.__initialized = True

def __setattr__(self, attr, val):

if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):

raise ValueError('{} attribute should not be set after {} is '

'initialized'.format(attr, self.__class__.__name__))

super(DataLoader, self).__setattr__(attr, val)

def __iter__(self):

return _DataLoaderIter(self)

def __len__(self):

return len(self.batch_sampler)

这里主要看__init__()和__iter__():

①数据的shuffle和batch处理

RandomSampler(dataset)

SequentialSampler(dataset)

BatchSampler(sampler, batch_size, drop_last)

②因为DataLoader只有__iter__()而没有实现__next__(),所以DataLoader是一个iterable而不是iterator。这个iterator的实现在_DataLoaderIter中。

1.1DataLoader之RandomSampler(dataset)、 SequentialSampler(dataset)

实现是在dataloader.py的同级目录下的torch/utils/data/sampler.py。sampler.py中实现了一个父类Sampler,以及SequentialSampler,RandomSampler和BatchSampler等五个继承Sampler的子类。对每个采样器,都需要提供__iter__方法用以表示数据遍历的方式和__len__方法用以返回数据的长度。

class Sampler(object):

r"""Base class for all Samplers.

Every Sampler subclass has to provide an __iter__ method, providing a way

to iterate over indices of dataset elements, and a __len__ method that

returns the length of the returned iterators.

"""

def __init__(self, data_source):

pass

def __iter__(self):

raise NotImplementedError

def __len__(self):

raise NotImplementedError

class SequentialSampler(Sampler):

r"""Samples elements sequentially, always in the same order.

Arguments:

data_source (Dataset): dataset to sample from

"""

def __init__(self, data_source):

self.data_source = data_source

def __iter__(self):

return iter(range(len(self.data_source)))

def __len__(self):

return len(self.data_source)

class RandomSampler(Sampler):

r"""Samples elements randomly, without replacement.

Arguments:

data_source (Dataset): dataset to sample from

"""

def __init__(self, data_source):

self.data_source = data_source

def __iter__(self):

return iter(torch.randperm(len(self.data_source)).tolist())

def __len__(self):

return len(self.data_source)

if __name__ == "__main__":

print(list(RandomSampler(range(10))))

#[2, 8, 3, 5, 9, 4, 6, 0, 1, 7]

print(list(SequentialSampler(range(10))))

#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

可以看出Ra

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值