Pytorch中的数据加载
目标
-
知道数据加载的目的
-
知道pytorch中Dataset的使用方法
-
知道pytorch中DataLoader的使用方法
-
知道pytorch中的自带数据集如何获取
1. 模型中使用数据加载器的目的
在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。
但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。
所以,接下来我们来学习pytorch中的数据加载的方法
2. 数据集类
2.1 Dataset基类介绍
在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。
torch.utils.data.Dataset的源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:
-
__len__方法,能够实现通过全局的len()方法获取其中的元素个数
-
__getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据
2.2 数据加载案例
下面通过一个例子来看看如何使用Dataset来加载数据
数据来源:http: