前言
在使用深度学习进行图片分类任务时, 往往都是用PyTorch进行模型搭建等操作. 如果需要进行的训练任务有现成的图片库, 那么如何将图片变成可以被PyTorch平台所承认的数据集呢? 这就需要自己创建可用的数据集, 就像很多案例中使用PyTorch自带的MNIST, Fashion-MNIST等数据集一样.
Dataset类
Dataset类是数据集的基类, 要想创建自己的数据集就需要自己创建数据集类集成并实现这个基类.
Dataset类位于PyTorch包中, 具体的位置为:
from torch.utils.data import Dataset
在utils.data
包也有加载器dataloader.
Dataset类的定义如下:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
其中给出了两个函数需要实现, 一个是__getitem__
并且是根据下标获取; 另一个则是__add__
, 如果只是简单进行分类任务的话, 实现__getitem__
函数即可.
数据集类模板
首先创建自己的数据集类, 之后这个数据集类要实现为一个对象从中获取数据集.
class MyDataset(Dataset):
def __init__(self, img_path):
pass
def __getitem__(self, idx):
pass
def __len__(self):
pass
这是自定义数据集的模板, 实现这三个函数就可以完成数据集的简单定义.
__init__(self, img_path)
这个函数用来初始化一些与数据集相关的参数. 比如这个数据集的路径信息. 通过这个路径信息可以将数据集中的所有图片都保存到这个类中. 此外, 由于必须要实现的函数__getitem__
是通过下标进行获取图片的, 那么势必要将整个数据集中的图片转换为列表的形式, 从而根据下标进行图片的获取.__getitem__(self, idx)
这个函数是必须要实现的, 这是从数据集中获取数据的接口. 因为要进行训练任务, 那么只返回训练数据(也就是图片)是不够的, 还需要返回相应的标签.__len__(self)
这个数据可以帮助我们查看数据集中样本的总数量, 毕竟整个数据集的图片都被放到了一个列表结构中, 获取整体的长度也很简单.
实现细节
首先做一个假设. 假设你的数据集中存储的格式是这样的
dataset
|-label_1
|--img_01.png
|--img_02.png
|-- ...
|-label_2
|--img_01.png
|--img_02.png
|-- ...
|-label_3
|- ...
就是说dataset文件夹下有许多子文件夹, 这些子文件夹中, 一个子文件夹就是单独一类, 每一类的图片就存放在对应的子文件夹下.
按照这个假设, 如何把整个数据集的图片用列表的形式保存起来, 同时也要保存对应的标签名呢?
在这里我只提供一个思路: 创建一个类, 保存图片的路径也保存图片的标签.
类似于:
class Img():
def __init__(self, img_path, img_label):
self.img_path = img_path
self.img_label = img_label
在保存信息的同时也可以保存其他的信息. 将图片组织成这样的对象存储起来, 然后在后面获取的时候再根据图片的路径创建图片, 再结合标签返回就可以了.
例如:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class Img():
def __init__(self, img_path, img_label):
self.img_path = img_path
self.img_label = img_label
class MyDataset(Dataset):
# 这里的dirlist是数据集文件夹的地址
# 下面的子文件夹才是各个类
def __init__(self, dirlist):
self.dirlist = dirlist
self.label_list = os.listdir(self.dirlist)
self.img_obj_list = [] # 存储Img对象
for label in self.label_list:
# 每一个label都是一个子文件夹名称
# 和dirlist拼接起来才是图片实际存储的地址
self.img_dir_path = os.path.join(self.dirlist, label)
self.img_list = os.listdir(self.img_dir_path)
for img_path in self.img_list:
img_obj_list.append(Img(os.path.join(self.img_dir_path, img_path), label))
def __getitem__(self, idx):
img_obj = self.img_obj_list[idx]
img = Image.open(img_obj.img_path)
# 因为导入pytorch中的图片格式必须是tensor或者ndarray, 所以要转换
# 但要注意: ndarray格式的图片通道为[H, W, C]而非pytorch常用的[C, H, W]
img = np.array(img)
label = img_obj.img_label
return img, label
def __len__(self):
return len(self.img_obj_list)
至此, 就简单完成了数据集的制作