写在之前
介绍
Pytorch深度学习框架优势之一是python优先,源代码由python代码层和C语言代码层组成,一般只需要理解python代码层就可以深入理解pytorch框架的计算原理。所以学习pytorch源码需要熟练掌握python语言的各种使用技巧。
在处理任何机器学习问题之前都需要数据读取,并进行预处理。Pytorch提供了许多方法使得数据读取和预处理变得很容易。
torch.utils.data.Dataset
是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__
和__getitem__
这两个方法就可以。- 通过继承
torch.utils.data.Dataset
的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader
类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
总之,通过torch.utils.data.Dataset
和torch.utils.data.DataLoader
这两个类,使数据的读取变得非常简单,快捷。
这两个抽象类中用到的python知识点
能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Dataset
和torch.utils.data.DataLoader
这两个类中会用到python抽象类的魔法方法,包括__len__(self)
,__getitem__(self)
和__iter__(self)
__len__(self)
定义当被len()
函数调用时的行为(返回容器中元素的个数)__getitem__(self)
定义获取容器中指定元素的行为,相当于self[key]
,即允许类对象可以有索引操作。__iter__(self)
定义当迭代容器中的元素的行为
下面通过介绍python定制容器的方式来介绍__len__(self)
,__getitem__(self)
两种方法。
在python中,像序列类型(如列表,元组和字符串)或映射类型(如字典)都属于容器类型。讲定制容器,那就必须要知道,定制容器有关的一些协议:
- 如果你希望定制的容器是不可变的话,你只需要定义
__len__()
和__getitem__
这两个魔法方法。 - 如果你希望定制的容器是可变的话,除了
__len__()
和__getitem__
这两个魔法方法,还需要定义__setitem__()
和__delitem__()
两个方法。
小案例:编写一个不可变的自定义列表,要求记录列表中每个元素被访问的次数。
class CountList:
def __init__(self, *args):
self.values = [x for x in args]
self.count = {
}.fromkeys(range(len(self.values)),0)
# 这里使用列表的下标作为字典的键,注意不能用元素作为字典的键
# 因为列表的不同下标可能有值一样的元素,但字典不能有两个相同的键
def __len__(self):
return len(self.values)
def __getitem__(self, key):
self.count[key] += 1
return self.values[key]
c1 = CountList(1,3,5,7,9)
c2 = CountLIst(2,4,6,8,10)
# 调用
c1[1] ## 3
c2[1] ## 4
c1[1] + c2[1] ## 7
c1.count ## {0:0,1:2,2:0,3:0,4:0}
c2.count ## {0:0,1:2,2:0,3:0,4:0}
接下来讲解__iter__(self)
方法。这个魔法方法是在python构造迭代器的时候需要定义的。迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典也是迭代器,都支持迭代操作。那么实现迭代器的魔法方法有两个:
__iter__()
__next__()
一个容器如果是迭代器,那就必须实现__iter__()
魔法方法,这个方法实际上是返回迭代器本身。接下来重点要实现的是__next__()
魔法方法,因为它决定了迭代的规则。举个简单的例子:
class Fibs:
def __init__(self, n=20):
self.a = 0
self.b = 1
self.n = n
def __iter__(self):
return self
def __next__(self):
self.a, self.b = self.b, self.a + self.b
if self.a > self.n:
raise StopIteration
return self.a
## 调用
fibs = Fibs()
for each in fibs:
print(each)
## 输出
1
1
2
3
5
8
13
torch.utils.data.Dataset类
源码:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
一个用来表示数据集的抽象类,其他所有的数据集都应该是这个类的子类,并且需要重写__len__
和__getitem__
。
torch.utils.data.DataLoader类
DataLoader类源码如下。先看看__init__
中的几个重要的输入:1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。2、batch_size,根据具体情况设置即可。3、shuffle,一般在训练数据中会采用。4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:
train_data=torch.utils.data.DataLoader(…)
for i, (input, target) in enumerate(train_data):
…
就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-proc