pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类

本文介绍了PyTorch中用于数据读取的torch.utils.data.Dataset和DataLoader类,阐述了它们的作用和用法。Dataset是自定义数据集的抽象类,DataLoader则提供批量、shuffle和多线程读取数据的功能。理解这两个类涉及的Python抽象类魔法方法,如__len__、__getitem__等。文章详细解析了Dataset和DataLoader的源码,包括DataLoader的工作流程,如多进程数据读取和pin_memory机制。
摘要由CSDN通过智能技术生成
写在之前
介绍

Pytorch深度学习框架优势之一是python优先,源代码由python代码层和C语言代码层组成,一般只需要理解python代码层就可以深入理解pytorch框架的计算原理。所以学习pytorch源码需要熟练掌握python语言的各种使用技巧。

在处理任何机器学习问题之前都需要数据读取,并进行预处理。Pytorch提供了许多方法使得数据读取和预处理变得很容易。

  • torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len____getitem__这两个方法就可以。
  • 通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
    总之,通过torch.utils.data.Datasettorch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。
这两个抽象类中用到的python知识点

能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Datasettorch.utils.data.DataLoader这两个类中会用到python抽象类的魔法方法,包括__len__(self)__getitem__(self)__iter__(self)

  • __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
  • __getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
  • __iter__(self)定义当迭代容器中的元素的行为

下面通过介绍python定制容器的方式来介绍__len__(self)__getitem__(self)两种方法。
在python中,像序列类型(如列表,元组和字符串)或映射类型(如字典)都属于容器类型。讲定制容器,那就必须要知道,定制容器有关的一些协议:

  • 如果你希望定制的容器是不可变的话,你只需要定义__len__()__getitem__这两个魔法方法。
  • 如果你希望定制的容器是可变的话,除了__len__()__getitem__这两个魔法方法,还需要定义__setitem__()__delitem__()两个方法。

小案例:编写一个不可变的自定义列表,要求记录列表中每个元素被访问的次数。

class CountList:
	def __init__(self, *args):
		self.values = [x for x in args]
		self.count = {
   }.fromkeys(range(len(self.values)),0)
		# 这里使用列表的下标作为字典的键,注意不能用元素作为字典的键
		# 因为列表的不同下标可能有值一样的元素,但字典不能有两个相同的键
	def __len__(self):
		return len(self.values)
	def __getitem__(self, key):
		self.count[key] += 1
		return self.values[key]
c1 = CountList(1,3,5,7,9)
c2 = CountLIst(2,4,6,8,10)

# 调用
c1[1]  ## 3
c2[1]  ## 4
c1[1] + c2[1] 	## 7
c1.count  ## {0:0,1:2,2:0,3:0,4:0}
c2.count  ## {0:0,1:2,2:0,3:0,4:0}	

接下来讲解__iter__(self)方法。这个魔法方法是在python构造迭代器的时候需要定义的。迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典也是迭代器,都支持迭代操作。那么实现迭代器的魔法方法有两个:

  • __iter__()
  • __next__()
    一个容器如果是迭代器,那就必须实现__iter__()魔法方法,这个方法实际上是返回迭代器本身。接下来重点要实现的是__next__()魔法方法,因为它决定了迭代的规则。举个简单的例子:
class Fibs:
	def __init__(self, n=20):
		self.a = 0
		self.b = 1
		self.n = n
	def __iter__(self):
		return self
	def __next__(self):
		self.a, self.b = self.b, self.a + self.b
		if self.a > self.n:
			raise StopIteration
		return self.a

## 调用
fibs = Fibs()
for each in fibs:
	print(each)
## 输出
1
1
2
3
5
8
13
torch.utils.data.Dataset类

源码:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

一个用来表示数据集的抽象类,其他所有的数据集都应该是这个类的子类,并且需要重写__len____getitem__

torch.utils.data.DataLoader类

DataLoader类源码如下。先看看__init__中的几个重要的输入:1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。2、batch_size,根据具体情况设置即可。3、shuffle,一般在训练数据中会采用。4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:
train_data=torch.utils.data.DataLoader(…)
for i, (input, target) in enumerate(train_data):

就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-proc
  • 79
    点赞
  • 360
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值