pytorch虽然简单易用,但是其高度的封装使得初学者难以理解数据是如何读入的。对于自己的任务,很可能pytorch提供的数据读取机制难以完全满足任务要求,所以我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程(文末附完整代码)。下面来分析一下Dataset类的源码【1】:
class Dataset(object):
"""此处省略"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
也就是说我们只需要实现“__getitem__(self, index)”方法即可,这个__getitem__()方法的index参数看起来有些让人困惑,查阅python官方文档【2】发现这个方法是object的方法:
所以这个index的值就是索引值,比如我们想要数据集中第二张图片索引就是2,关于__getitem__()的具体的介绍可以看【3】。
一、自定义SingeClassDataset
我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:
- 初始化图像路径模块__init__()
- 图像转Tensor模块_read_convert_image()
- 数据索引模块__getitem__()
由于目的是继承Dataset类,所以应该采用__init__()来存储数据路径,在存储之前要先检查输入的路径是不是正确的路径,防止图片读取失败:
def __init__(self, file_path):
# 保证输入的是正确的路径
if not os.path.isdir(file_path):
raise ValueError("input file_path is not a dir")
self.file_path = file_path
# 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
self.image_list = os.listdir(file_path)
# 将PIL的Image转为Tensor
self.transforms = T.ToTensor()
读取图像采用python内置的PIL提供的Image类型,这也是pytorch支持的核心类型。读取Image类型的图片后可以直接通过torchvision提供的变换,转为pytorch需要的Tensor类型:
def _read_convert_image(self, image_name):
image = Image.open(image_name)
image = self.transforms(image).float()
return image
拥有上述两个方法以后,就可以实现完整的数据读取了。根据图像的存储方式不同可以采用多种读取策略,常见的情况有两种:图像在一个文件夹中、图像在多个文件夹中。下面实现的__getitem__()方法针对于所有的图像在一个文件夹内的情况:
def __getitem__(self, index):
# 根据index获取图片完整路径
image_path = os.path.join(self.file_path, self.image_list[index])
# 都图片并转为Tensor
image = self._read_convert_image(image_path)
return image
二、测试自定义的SingleClassDataset
测试之前首先准备好数据放到“data”文件夹中,如下:
定义了__getitem__()方法的类就可以通过“索引”获得数据,下面来看一下数据是正确的读入了,可视化采用的是matplotlib,下面的代码展示了如何可视化前16张图片:
import matplotlib.pyplot as plt
MyDataset = SingleClassDataset(file_path="data/")
plt.figure()
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(MyDataset[i].numpy().transpose(1, 2, 0))
plt.show()
得到了下面的结果,就说明数据读取是没问题的:
三、完整代码
由于之前pytorch的版本还必须实现"__len__()"方法用于返回数据集的长度,所以下面的代码实现了它,但是当前版本的pytorch已经不再强制实现这个函数了,整体代码如下:
from torch.utils.data import Dataset
import os
from PIL import Image
import torchvision.transforms as T
class SingleClassDataset(Dataset):
"""
This Dataset only work for a folder that contains one class image!!!
"""
def __init__(self, file_path):
# 保证输入的是正确的路径
if not os.path.isdir(file_path):
raise ValueError("input file_path is not a dir")
self.file_path = file_path
# 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
self.image_list = os.listdir(file_path)
# 将PIL的Image转为Tensor
self.transforms = T.ToTensor()
def __getitem__(self, index):
# 根据index获取图片完整路径
image_path = os.path.join(self.file_path, self.image_list[index])
# 都图片并转为Tensor
image = self._read_convert_image(image_path)
return image
def _read_convert_image(self, image_name):
image = Image.open(image_name)
image = self.transforms(image).float()
return image
def __len__(self):
return len(self.image_list)
import matplotlib.pyplot as plt
MyDataset = SingleClassDataset(file_path="data/")
plt.figure()
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(MyDataset[i].numpy().transpose(1, 2, 0))
plt.show()
参考:
【1】https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py
【2】https://docs.python.org/3/reference/datamodel.html#object.__getitem__
【3】https://zhuanlan.zhihu.com/p/87786297