现在我们已经可以从磁盘中将数据读取出来,并获得一个图像路径列表和一个标签列表,那么如何将已读取出来的数据用于深度学习的训练呢?这里就需要用到Pytorch提供的torch.utils.data.Dataset类。
Pytorch官方教程链接:https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
torch.utils.data.Dataset是专门用来将数据处理成单个样本对的类,如果要实现自己的dataset类,必须要重写3个函数:__ init __ 、__ len __ 和 __ getitem __ 。其中, __ init __ 用于初始化一些必要的成员变量,__ len __ 函数返回数据集的长度,__ getitem __ 最为重要,用于将数据处理成单个样本对。
这里还是分为图像分类和目标检测这2个部分分别实现各自的dataset类。
图像分类篇
图像分类任务主要是对图像中的物体进行分类或者说识别,那么我们首先需要很多图像,另外,为了可以指导模型的训练,那么我们就需要知道物体的类别,也就是图像分类任务的标签,这就是为什么需要图像路径列表和标签列表的原因(对《计算机视觉技巧合集(二)如何读取数据之目标检测篇》问题的回答)。
ClassDataset示例代码如下:
from torch.utils.data import Dataset
import random
import os
import json
from PIL import Image
def read_split_three_data(root: str, train_val_rate: float = 0.8, train_rate: float = 0.8):
# 详细内容见计算机视觉技巧合集(一)如何读取数据
return train_images_path, train_images_label, val_images_path, val_images_label
class ClassDataset(Dataset):
def __init__(self, img_paths, labels, tranforms=None):
# 初始化图像路径变量
self.img_paths = img_paths
# 初始化标签变量
self.labels = labels
# 初始化图像后处理和数据增强函数变量
self.tranforms = tranforms
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
# 打开图像,转成RGB格式
img = Image.open(self.img_paths[idx]).convert('RGB')
# 判断转换是否成功
if img.mode != 'RGB':
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))
# 进行图像后处理和数据增强
if self.tranforms is not None:
img = self.tranforms(img)
# 读取标签
label = self.labels[idx]
return img, label
if __name__ == "__main__":
root = r"G:\datasets\flower_photos"
train_images_path, train_images_label, val_images_path, val_images_label = read_split_three_data(root, train_val_rate=0.8, train_rate=0.75)
# 创建train_dataset
train_dataset = ClassDataset(train_images_path, train_images_label)
# 展示前5个样本对
for index, data in enumerate(train_dataset):
img, label = data
print(img, type(img), label, type(label))
if index == 4:
break
程序运行结果:
通过ClassDataset类实现可以看出,我将读取数据和生成样本对解耦合,相互独立了出来,这样逻辑顺序更加清晰,并且便于实现和调试,读取数据函数只需要返回一个图像路径列表和一个标签列表,而自己的dataset类只需要返回处理好的图像和标签样本对即可。
从train_dataset中包含的数据可以看出现在的图像格式仍然是PIL.Image.Image类型,并不是张量类型,另外,每一张图像的大小都不一样,这样无法将其打包成多维数组,所以还不能用于训练,因此,还需要进行图像的后处理,将图像变形成统一的大小,并从Image类型转换成tensor类型。
图像后处理示例代码如下:
from torch.utils.data import Dataset
import random
import os
import json
from PIL import Image
from torchvision import transforms
def read_split_three_data(root: str, train_val_rate: float = 0.8, train_rate: float = 0.8):
# 详细内容见计算机视觉技巧合集(一)如何读取数据
return train_images_path, train_images_label, val_images_path, val_images_label
class ClassDataset(Dataset):
......
return img, label
if __name__ == "__main__":
root = r"G:\datasets\flower_photos"
train_images_path, train_images_label, val_images_path, val_images_label = read_split_three_data(root, train_val_rate=0.8, train_rate=0.75)
# 图像后处理
train_transforms = transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor()])
# 创建train_dataset
train_dataset = ClassDataset(train_images_path, train_images_label, train_transforms)
# 展示前5个样本对
for index, data in enumerate(train_dataset):
img, label = data
print(img.dtype, img.shape, label, type(label))
if index == 4:
break
程序运行结果如下:
可以看出,每一张图像的类型都是torch.float32浮点类型,并且size都是3x224x224,这样我们就得到处理好的单个样本对了。
目标检测篇
目标检测任务其实也是同理的,这个任务除了对物体分类还要定位物体的位置,也就是框出物体,那么除了需要物体的类别外,我们当然还需要预先框出物体,有了真实框才可以训练目标检测的模型。这里只实现使用读取VOC类型数据集的load_data_from_txt函数来制作样本对,对于COCO类型数据集的实现是一样的,因此就不一一实现了。
DetectDataset示例代码如下:
import os
import numpy as np
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
class_names = [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]
def load_data_from_txt(text, img_root, anno_root, remove_difficult=False):
# 详细内容见计算机视觉技巧合集(二)如何读取数据之目标检测篇-补充1
return img_paths, all_labels
class DetectDataset(Dataset):
def __init__(self, img_paths, labels, tranforms=None):
# 初始化图像路径变量
self.img_paths = img_paths
# 初始化标签变量
self.labels = labels
# 初始化图像处理和数据增强函数变量
self.tranforms = tranforms
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
# 打开图像,转成RGB格式
img = Image.open(self.img_paths[idx]).convert('RGB')
# 判断转换是否成功
if img.mode != 'RGB':
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))
# 进行图像预处理和数据增强
if self.tranforms is not None:
img = self.tranforms(img)
# 读取标签
label = self.labels[idx]
return img, label
if __name__ == "__main__":
text_path = r"G:\datasets\VOCdevkit\VOC2012\ImageSets\Main\train.txt"
img_root = r"G:\datasets\VOCdevkit\VOC2012\JPEGImages"
anno_root = r"G:\datasets\VOCdevkit\VOC2012\Annotations"
img_paths, all_labels = load_data_from_txt(text_path, img_root, anno_root, remove_difficult=True)
print(f"图像总数: {len(img_paths)}")
print(f"标签总数: {len(all_labels)}")
train_transforms = transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor()])
train_dataset = DetectDataset(img_paths, all_labels, train_transforms)
# 展示前2个样本对
for index, data in enumerate(train_dataset):
img, label = data
print(f"第{index}张图像和对应的标签")
print(img.dtype, img.shape)
print(label, type(label))
if index == 1:
break
程序运行结果如下:
可以看出不论是ClassDataset和DetectDataset返回的图像都只是简单的变形和转换成tensor类型,这样虽然可以作为模型的输入,但由于样本图像比较简单,并且不够多样化,所以模型在训练过程中其实很容易就会过拟合,所以为了获得更加多样更加复杂的样本图像,我们在训练中还需要对其进行数据增强,比如翻转图像、旋转图像、拼接图像和在图像中加入马赛克等等方法。下一篇会细讲如何使用torchvision已有的数据增强方法以及实现更为复杂的数据增强方法。