深度时代,数据为王。
PyTorch为我们提供的两个Dataset和DataLoader类分别负责可被Pytorhc使用的数据集的创建以及向训练传递数据的任务。如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。
Dataset是DataLoader实例化的一个参数,所以这篇文章会先从Dataset的源代码讲起,然后下一篇讲到DataLoader,关注主要函数,少细枝末节,目的是使大家学会自定义自己的数据集。
ps: 本文搬运自作者的博客 陈亮的博客 | Liang's Blog,里面有一些完成/待完成的文章,欢迎大家一起交流,转载请注明。
Dataset
什么时候使用Dataset
CIFAR10是CV训练中经常使用到的一个数据集,在PyTorch中CIFAR10是一个写好的Dataset,我们使用时只需以下代码:
data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。
我们有的时候需要用自己在一个文件夹中的数据作为数据集,这个时候,我们可以使用ImageFolder这个方便的API。
FaceDataset = datasets.ImageFolder('./data', transform=img_transform)
如何自定义一个数据集
torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。
所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。
Pytorch提供两种数据集: Map式数据集 Iterable式数据集
Map式数据集
一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map).
这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。
自定义类大致是这样的:
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3