1、模型中使用数据加载器的目的
在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成个个的batch,同时还会对数据进行预处理。
所以,接下来介绍pytorch中的数据加载的方法。
2、数据集类
2.1 Dataset基类介绍:
在torch中提供了数据集的基类
torch.utils.data.Dataset
, 继承这个基类,我们能够非常快速的实现对数据的加载。
torch.utils.data.Dataset
的源码如下:
from torch.utils.data import Dataset
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __len__(se1f):
raise NotImp lementedError
def __add__(se1f, other):
return ConcatDataset([self, other])
可知:我们需要在自定义的数据集类中