pytorch虽然简单易用,但是其高度的封装使得初学者难以理解数据是如何读入的。对于自己的任务,很可能pytorch提供的数据读取机制难以完全满足任务要求,所以我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程(文末附完整代码)。下面来分析一下Dataset类的源码【1】:
class
也就是说我们只需要实现“__getitem__(self, index)”方法即可,这个__getitem__()方法的index参数看起来有些让人困惑,查阅python官方文档【2】发现这个方法是object的方法:
所以这个index的值就是索引值,比如我们想要数据集中第二张图片索引就是2,关于__getitem__()的具体的介绍可以看【3】。
一、自定义SingeClassDataset
我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:
- 初始化图像路径模块__init__()
- 图像转Tensor模块_read_convert_image()
- 数据索引模块__getitem__()
由于目的是继承Dataset类,所以应该采用__init__()来存储数据路径,在存储之前要先检查输入的路径是不是正确的路径,防止图片读取失败:
def
读取图像采用python内置的PIL提供的Image类型,这也是pytorch支持的核心类型。读取Image类型的图片后可以直接通过torchvision提供的变换,转为pytorch需要的Tensor类型:
def
拥有上述两个方法以后,就可以实现完整的数据读取了。根据图像的存储方式不同可以采用多种读取策略,常见的情况有两种:图像在一个文件夹中、图像在多个文件夹中。下面实现的__getitem__()方法针对于所有的图像在一个文件夹内的情况:
def
二、测试自定义的SingleClassDataset
测试之前首先准备好数据放到“data”文件夹中,如下:
定义了__getitem__()方法的类就可以通过“索引”获得数据,下面来看一下数据是正确的读入了,可视化采用的是matplotlib,下面的代码展示了如何可视化前16张图片:
import
得到了下面的结果,就说明数据读取是没问题的:
三、完整代码
由于之前pytorch的版本还必须实现"__len__()"方法用于返回数据集的长度,所以下面的代码实现了它,但是当前版本的pytorch已经不再强制实现这个函数了,整体代码如下:
from
参考:
【1】https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py
【2】https://docs.python.org/3/reference/datamodel.html#object.__getitem__
【3】https://zhuanlan.zhihu.com/p/87786297