一般的深度学习训练模型的搭建框架过程为,导入数据-建立模型-训练与测试/迁移学习,在这篇笔记中,我主要记录了自定义一个自己的数据集过程与迁移学习的方法。对于其中涉及的到的训练过程与测试过程在其他的笔记中已有提到。
对于之前用到的MNIST数据集与Cifar10数据集的导入,其实我们都只是利用了pytorch提供的函数,分别是torchvision.datasets.MNIST与torchvision.datasets.CIFAR10两个函数帮助我们实现了样本数据的导入。但是,当我们需要训练我们自己的数据集时,具体的datasets操作函数便需要我们来编写。
对于我们设计自定义数据集类时,具体有三个步骤:
- 继承torch.utils.data中的Dataset类
- 编写 __ len __ ()函数
- 编写 __ getitem __ ()函数
源码中的Dataset如下:
class Dataset(Generic[T_co]