dataloader 源码_PyTorch源码解读之torch.utils.data.DataLoader

PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口,该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要。这篇博客介绍该接口的源码,主要包含DataLoader和DataLoaderIter两个类。

理解这些类需要知道的python知识点

能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类中会用到python抽象类的魔法方法,包括__len__(self),__getitem__(self)和_iter__(self)__len__(self)定义当被len()函数调用时的行为(返回容器中元素的个数)

__getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。

__iter__(self)定义当迭代容器中的元素的行为

接下来讲解__iter__(self)方法。这个魔法方法是在python构造迭代器的时候需要定义的。迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典也是迭代器,都支持迭代操作。那么实现迭代器的魔法方法有两个:__iter__()

__next__()

一个容器如果是迭代器,那就必须实现__iter__()魔法方法,这个方法实际上是返回迭代器本身。接下来重点要实现的是__next__()魔法方法,因为它决定了迭代的规则。举个简单的例子:

class Fibs:

def __init__(self, n=20):

self.a = 0

self.b = 1

self.n = n

def __iter__(self):

return self

def __next__(self):

self.a, self.b = self.b, self.a + self.b

if self.a > self.n:

raise StopIteration

return self.a

## 调用fibs = Fibs()

for each in fibs:

print(each)

## 输出1

1

2

3

5

8

13

torch.utils.data.DataLoader参数:dataset (Dataset) – 加载数据的数据集。

batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。

shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).

sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序

num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。

pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。

drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:

train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=True)

for i, data in enumerate(train_loader):

inputs, labels = data

就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

class DataLoaderIter(object):

"Iterates once over the DataLoader's dataset, as specified by the sampler"

def __init__(self, loader):

self.dataset = loader.dataset

self.collate_fn = loader.collate_fn

self.batch_sampler = loader.batch_sampler

self.num_workers = loader.num_workers

self.pin_memory = loader.pin_memory and torch.cuda.is_available()

self.timeout = loader.timeout

self.done_event = threading.Event()

self.sample_iter = iter(self.batch_sampler)

if self.num_workers > 0:

self.worker_init_fn = loader.worker_init_fn

self.index_queue = multiprocessing.SimpleQueue()

self.worker_result_queue = multiprocessing.SimpleQueue()

self.batches_outstanding = 0

self.worker_pids_set = False

self.shutdown = False

self.send_idx = 0

self.rcvd_idx = 0

self.reorder_dict = {}

base_seed = torch.LongTensor(1).random_()[0]

self.workers = [

multiprocessing.Process(

target=_worker_loop,

args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,

base_seed + i, self.worker_init_fn, i))

for i in range(self.num_workers)]

if self.pin_memory or self.timeout > 0:

self.data_queue = queue.Queue()

self.worker_manager_thread = threading.Thread(

target=_worker_manager_loop,

args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,

torch.cuda.current_device()))

self.worker_manager_thread.daemon = True

self.worker_manager_thread.start()

else:

self.data_queue = self.worker_result_queue

for w in self.workers:

w.daemon = True # ensure that the worker exits on process exit

w.start()

_update_worker_pids(id(self), tuple(w.pid for w in self.workers))

_set_SIGCHLD_handler()

self.worker_pids_set = True

# prime the prefetch loop for _ in range(2 * self.num_workers):

self._put_indices()

首先执行的是DataLoader的初始化工作,比如依次进入DataLoader(object)类的__iter__(self)方法,选择单进程还是多进程的DataLoader迭代器,然后默认进入单进程的DataLoader迭代器类_SingleProcessDataLoaderIter(_BaseDataLoaderIter)初始化方法__init__(self, loader),然后进入_BaseDataLoaderIter(object)类的初始化方法 init(self, loader)等。

class DataLoader(object):

# ......

def __iter__(self):

if self.num_workers == 0:

return _SingleProcessDataLoaderIter(self)

else:

return _MultiProcessingDataLoaderIter(self)

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):

def __init__(self, loader):

super(_SingleProcessDataLoaderIter, self).__init__(loader)

assert self._timeout == 0

assert self._num_workers == 0

self._dataset_fetcher = _DatasetKind.create_fetcher(

self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

class _BaseDataLoaderIter(object):

def __init__(self, loader):

self._dataset = loader.dataset

self._dataset_kind = loader._dataset_kind

self._IterableDataset_len_called = loader._IterableDataset_len_called

self._auto_collation = loader._auto_collation

self._drop_last = loader.drop_last

self._index_sampler = loader._index_sampler

self._num_workers = loader.num_workers

self._pin_memory = loader.pin_memory and torch.cuda.is_available()

self._timeout = loader.timeout

self._collate_fn = loader.collate_fn

self._sampler_iter = iter(self._index_sampler)

self._base_seed = torch.empty((), dtype=torch.int64).random_().item()

self._num_yielded = 0

DataLoader的初始化完成之后,便开始读数据。首先进入的是_BaseDataLoaderIter(object)类def __next__(self)方法。然后跳转到_SingleProcessDataLoaderIter类的_next_data(self)方法,然后进入BatchSampler(Sampler)采样器类的__iter__(self)方法,就是在这里生成一个batch数据的所有索引值,并放到一个列表中,如图4所示。再次进入_SingleProcessDataLoaderIter类的_next_data(self)方法,把生成的index给_dataset_fetcher.fetch(index) 。最后跳到我们自定义的DataSet(Dataset):类的__getitem__(self, index),一次给它一个索引,循环batch_size次。循环完成之后默认将进入default_collate(batch)方法,整合这个batch的数据。

总的来说,如果设置为多进程读取数据,那么就会采用队列的方式来读,如果不是采用多进程来读取数据,那就采用普通方式来读.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值