dataset__getitem___pytorch中继承Dataset自定义数据读取流

v2-51fdff470a64691a21c60e53ba828062_1440w.jpg?source=172ae18b

pytorch虽然简单易用,但是其高度的封装使得初学者难以理解数据是如何读入的。对于自己的任务,很可能pytorch提供的数据读取机制难以完全满足任务要求,所以我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程(文末附完整代码)。下面来分析一下Dataset类的源码【1】:

class 

也就是说我们只需要实现“__getitem__(self, index)”方法即可,这个__getitem__()方法的index参数看起来有些让人困惑,查阅python官方文档【2】发现这个方法是object的方法:

v2-a9d4b74b62db2c34ddb9538529fd147d_b.jpg

所以这个index的值就是索引值,比如我们想要数据集中第二张图片索引就是2,关于__getitem__()的具体的介绍可以看【3】。

一、自定义SingeClassDataset

我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:

  • 初始化图像路径模块__init__()
  • 图像转Tensor模块_read_convert_image()
  • 数据索引模块__getitem__()

由于目的是继承Dataset类,所以应该采用__init__()来存储数据路径,在存储之前要先检查输入的路径是不是正确的路径,防止图片读取失败:

def 

读取图像采用python内置的PIL提供的Image类型,这也是pytorch支持的核心类型。读取Image类型的图片后可以直接通过torchvision提供的变换,转为pytorch需要的Tensor类型:

def 

拥有上述两个方法以后,就可以实现完整的数据读取了。根据图像的存储方式不同可以采用多种读取策略,常见的情况有两种:图像在一个文件夹中、图像在多个文件夹中。下面实现的__getitem__()方法针对于所有的图像在一个文件夹内的情况:

def 

二、测试自定义的SingleClassDataset

测试之前首先准备好数据放到“data”文件夹中,如下:

v2-28784e92ca39dc6d4337eafcd8cd4245_b.png

定义了__getitem__()方法的类就可以通过“索引”获得数据,下面来看一下数据是正确的读入了,可视化采用的是matplotlib,下面的代码展示了如何可视化前16张图片:

import 

得到了下面的结果,就说明数据读取是没问题的:

v2-ec42f13b024003c1f65012675ee054a6_b.jpg

三、完整代码

由于之前pytorch的版本还必须实现"__len__()"方法用于返回数据集的长度,所以下面的代码实现了它,但是当前版本的pytorch已经不再强制实现这个函数了,整体代码如下:

from 

参考:

【1】https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py

【2】https://docs.python.org/3/reference/datamodel.html#object.__getitem__

【3】https://zhuanlan.zhihu.com/p/87786297

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值