对于如何定义自己的Datasets我讲从以下几个方面进行解说
**
1.什么是Datasets?
2.为什么要定义Datasets?
3.如何定义Datasets?
定义Datasets分为以下几个板块:
1)Datasets的源代码及解说
2)Datasets的整体框架及解说
3)自己的Datasets框架及解说
4)DataLoader的使用
5)如何生成txt文件
什么是Datasets?
Datasets是我们用的数据集的库,我们知道pytorch自带多种数据集列如Cifar10数据集就是在pytorch的Datasets的库中的。
为什么要定义Datasets?
Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。
如何定义Datasets?
Dataset类
Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:
def getitem(self, index):
def len(self):
其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数
这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
那么读取自己数据的基本流程就是:
制作存储了图片的路径和标签信息的txt
将这些信息转化为list,该list每一个元素对应一个样本
通过getitem函数,读取数据和标签,并返回数据和标签
定义自己的数据集类
1)Datasets的源代码及解说
All datasets are subclasses of torch.utils.data.Dataset i.e,
they have __getitem__ and __len__ methods implemented.
Hence, they can all be passed to a torch.
utils.data.DataLoader which can load multiple samples parallelly using torch.multiprocessing workers.
[源代码地址(https://pytorch.org/docs/stable/torchvision/datasets.html)
从源代码我们可以看出继承Datasets必须继承__init_()和__getitim__()
首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取.
2)Datasets的整体框架及解说
class FirstDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. 初始化文件路径或文件名列表。
#也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
pass
def __getitem__(self, index):
# TODO
#1。从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
#2。预处理数据(例如torchvision.Trans