在上一篇博客PyTorch学习之路(level1)——训练一个图像分类模型中介绍了如何用PyTorch训练一个图像分类模型,建议先看懂那篇博客后再看这篇博客。在那份代码中,采用torchvision.datasets.ImageFolder这个接口来读取图像数据,该接口默认你的训练数据是按照一个类别存放在一个文件夹下。但是有些情况下你的图像数据不是这样维护的,比如一个文件夹下面各个类别的图像数据都有,同时用一个对应的标签文件,比如txt文件来维护图像和标签的对应关系,在这种情况下就不能用torchvision.datasets.ImageFolder来读取数据了,需要自定义一个数据读取接口。另外这篇博客最后还顺带介绍如何保存模型和多GPU训练。
怎么做呢?
先来看看torchvision.datasets.ImageFolder这个类是怎么写的,主要代码如下,想详细了解的可以看:官方github代码。
看起来很复杂,其实非常简单。继承的类是torch.utils.data.Dataset,主要包含三个方法:初始化__init__
,获取图像__getitem__
,数据集数量 __len__
。__init__
方法中先通过find_classes函数得到分类的类别名(classes)和类别名与数字类别的映射关系字典(class_to_idx)。然后通过make_dataset函数得到imags,这个imags是一个列表,其中每个值是一个tuple,每个tuple包含两个元素:图像路径和标签。剩下的就是一些赋值操作了。在__getitem__
方法中最重要的就是 img = self.loader(path)这行,表示数据读取,可以从__init__
方法中看出self.loader采用的是default_loader,这个default_loader的核心就是用python的PIL库的Image模块来读取图像数据。
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes