献给学习PyTorch在路上或者计划较深入理解PyTorch的同行者们
写在前面
笔者一直使用tf,大势所趋决定转PyTorch,这个系列就作为我学习PyTorch的笔记与心得。
网络上PyTorch学习资源多如牛毛,因此我想写点不一样的。
我的计划:每篇文章介绍PyTorch一个/多个模块,内容主要为:
源码解析+实践demo
希望在代码层面更加透彻地理解和使用PyTorch(初步计划,后期可能会有改动)
1 源码解析
PyTorch的数据加载模块,一共涉及到Dataset,Sampler,Dataloader三个类
Dataset负责对raw data source封装,将其封装成Python可识别的数据结构,其必须提供提取数据个体的接口。Dataset共有Map-style datasets和Iterable-style datasets两种:
map-style dataset:实现了__getitem__和__len__接口,表示一个从索引/key到样本数据的map。比如:datasets[10],就表示第10个样本。
iterable-style dataset:实现了__iter__接口,表示在data samples上的一个Iterable(可迭代对象),这种形式的dataset非常不适合随机存取(代价太高),但非常适合处理流数据。比如:iter(datasets)获得迭代器,然后不断使用next迭代从而实现遍历。
Sampler负责提供一种遍历数据集所有元素索引的方式。
Dataloader负责加载数据,同时支持map-style和iterable-style Dataset,支持单进程/多进程,还可以设置loading order, batch size, pin memory等加载参数。
这三者的关系就一目了然了。
设置Dataset,将数据data source包装成Dataset类,暴露提取接口。
设置Sampler,决定采样方式。我们是能从Dataset中提取元素了,还是需要设置Sampler告诉程序提取Dataset的策略。
将设置好的Dataset和Sampler传入DataLoader,同时可以设置shuffle,batch_size等参数。使用DataLoader对象可以快捷方便地在给定数据集上遍历。
总结来说,即Dataloader负责总的调度,命令Sampler定义遍历索引的方式,然后用索引去Dataset中提取元素。于是就实现了对给定数据集的遍历。
1.1 Dataset
torch.utils.data.Dataset:抽象基类
所有的Dataset相关类都应该继承自torch.utils.data.Dataset这个类
这些子类必须要实现方法__getitem__(),来支持可以给定一个key(即索引)来获取对应的数据样本
这些类可以实现方法__len__(),来返回数据集的大小规模
Dataset实现非常简洁,就只是提供了__getitem__ 和 __add__这两个接口。
前者很重要,是Dataset及其子类的核心,定义了数据元素提取(即通过索引获取样本,实际代码中常使用[]输入索引)
class Dataset(object):def __getitem__(self, index): raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])
具体实践中,我们需要使用Dataset的子类,自己实现的或者现成的。
我们可以来看看PyTorch为我们提供的现成的Dataset子类:
TensorDataset
IterableDataset
ConcatDataset
ChainDataset
Subset
下面着重介绍TensorDataset和IterableDataset.
CLASS torch.utils.data.TensorDataset(*tensors)
包装了Tensor的Dataset子类,map-style dataset
每个样本可以通过tensors第一个维度的索引获取
class TensorDataset(Dataset):r""" Arguments: *tensors (Tensor): tensors that have the same size of the first dimension. """def __init__(self, *tensors):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in self.tensors)def __len__(self):return self.tensors[0].size(0)
如上源码:
__init__的形参是*tensors,因此是可以传入多个tensor变量的,但需要保证每个tensor的第一个维度均是一样的。
栗子:输入张量:
正确输入:
错误输入:
__getitem__提取的就是*tensors中每个张量的第index个样本(因为每个张量第一维度都是一样的)
__len__即*tensors每个张量第一个维度长度
常见用法:*tensors指定我们可以输入多个张量,我们可以同时输入train_data和train_label
dataset = TensorDataset(train_data, train_label)
CLASS torch.utils.data.IterableDataset
内部样本的组织形式是Iterable的所有dataset类都是IterableDataset类的子类,即:所有iterable-style dataset都是IterableDataset的子类。
这种形式的dataset对于处理流数据是非常有用的。
所有这些子类需要实现__iter__方法(而不是__getitem__方法了),需要据此来返回样本的迭代器,从而遍历dataset(实际代码中常使用iter+next来遍历)
关于Python中Iterable和Iterator的介绍见我的另一篇文章: 刘昕宸:彻底搞懂Python的__iter__和__next__,Iterable和Iteration
class IterableDataset(Dataset[T_co]):def __iter__(self) -> Iterator[T_co]:raise NotImplementedErrordef __add__(self, other: Dataset[T_co]):return ChainDataset([self, other])
关于多进程的问题:
IterableDataset的某个子类被DataLoader使用时,dataset中的每个item可以通过DataLoader的Iterator迭代获取。
当num_works>0时就是多进程模式,每个工作进程都有一个不同的dataset对象的拷贝,因此我们需要独立安排每一份拷贝该如何处理(后面会有例子),以防止不同的进程会返回重复的元素。(有MPI编程经验的同学应该更能理解!)
可以通过get_worker_info方法,在某一当前进程中调用,获得当前进程信息。这个方法要么在dataset类的__iter__方法中使用,要么在DataLoader的worker_init_fn方法中设置并使用。
举2个例子(来自官网文档):
例1:在dataset类的__iter__方法中使用get_worker_info方法,划分工作空间,获得当前进程id,并根据进程id分配其需要处理的工作空间
class MyIterableDataset(torch.utils.data.IterableDataset):def __init__(self, start, end):super(MyIterableDataset).__init__() assert end > start, "this example code only works with end >= start"self.start = startself.end = enddef __iter__(self): worker_info = torch.utils.data.get_worker_info()if worker_info is None: # 单进程:一个进程处理全部样本 iter_start = self.start iter_end = self.endelse: # 多进程,在当前进程中# 划分工作空间