计算机视觉技巧合集(三)实现自己的dataset类

现在我们已经可以从磁盘中将数据读取出来,并获得一个图像路径列表和一个标签列表,那么如何将已读取出来的数据用于深度学习的训练呢?这里就需要用到Pytorch提供的torch.utils.data.Dataset类。

Pytorch官方教程链接:https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

torch.utils.data.Dataset是专门用来将数据处理成单个样本对的类,如果要实现自己的dataset类,必须要重写3个函数:__ init __ 、__ len __ 和 __ getitem __ 。其中, __ init __ 用于初始化一些必要的成员变量,__ len __ 函数返回数据集的长度,__ getitem __ 最为重要,用于将数据处理成单个样本对。

这里还是分为图像分类和目标检测这2个部分分别实现各自的dataset类。

图像分类篇

图像分类任务主要是对图像中的物体进行分类或者说识别,那么我们首先需要很多图像,另外,为了可以指导模型的训练,那么我们就需要知道物体的类别,也就是图像分类任务的标签,这就是为什么需要图像路径列表和标签列表的原因(对《计算机视觉技巧合集(二)如何读取数据之目标检测篇》问题的回答)。

ClassDataset示例代码如下:

from torch.utils.data import Dataset
import random
import os
import json
from PIL import Image

def read_split_three_data(root: str, train_val_rate: float = 0.8, train_rate: float = 0.8):
    # 详细内容见计算机视觉技巧合集(一)如何读取数据
    return train_images_path, train_images_label, val_images_path, val_images_label

class ClassDataset(Dataset):
    def __init__(self, img_paths, labels, tranforms=None):
        # 初始化图像路径变量
        self.img_paths = img_paths
        # 初始化标签变量
        self.labels = labels
        # 初始化图像后处理和数据增强函数变量
        self.tranforms = tranforms

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        
        # 打开图像,转成RGB格式
        img = Image.open(self.img_paths[idx]).convert('RGB')
        
        # 判断转换是否成功
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))

        # 进行图像后处理和数据增强
        if self.tranforms is not None:
            img = self.tranforms(img)
        
        # 读取标签
        label = self.labels[idx]

        return img, label

if __name__ == "__main__":
    root = r"G:\datasets\flower_photos"
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_three_data(root, train_val_rate=0.8, train_rate=0.75)

	# 创建train_dataset
    train_dataset = ClassDataset(train_images_path, train_images_label)
    # 展示前5个样本对
    for index, data in enumerate(train_dataset):
        img, label = data
        print(img, type(img), label, type(label))

        if index == 4:
            break

程序运行结果:
在这里插入图片描述
通过ClassDataset类实现可以看出,我将读取数据和生成样本对解耦合,相互独立了出来,这样逻辑顺序更加清晰,并且便于实现和调试,读取数据函数只需要返回一个图像路径列表和一个标签列表,而自己的dataset类只需要返回处理好的图像和标签样本对即可。

从train_dataset中包含的数据可以看出现在的图像格式仍然是PIL.Image.Image类型,并不是张量类型,另外,每一张图像的大小都不一样,这样无法将其打包成多维数组,所以还不能用于训练,因此,还需要进行图像的后处理,将图像变形成统一的大小,并从Image类型转换成tensor类型。

图像后处理示例代码如下:

from torch.utils.data import Dataset
import random
import os
import json
from PIL import Image
from torchvision import transforms

def read_split_three_data(root: str, train_val_rate: float = 0.8, train_rate: float = 0.8):
	# 详细内容见计算机视觉技巧合集(一)如何读取数据
    return train_images_path, train_images_label, val_images_path, val_images_label

class ClassDataset(Dataset):
		......
        return img, label

if __name__ == "__main__":
    root = r"G:\datasets\flower_photos"
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_three_data(root, train_val_rate=0.8, train_rate=0.75)

	# 图像后处理
    train_transforms = transforms.Compose([transforms.Resize((224, 224)),
                                           transforms.ToTensor()])
	# 创建train_dataset
    train_dataset = ClassDataset(train_images_path, train_images_label, train_transforms)
    # 展示前5个样本对
    for index, data in enumerate(train_dataset):
        img, label = data
        print(img.dtype, img.shape, label, type(label))

        if index == 4:
            break

程序运行结果如下:
在这里插入图片描述
可以看出,每一张图像的类型都是torch.float32浮点类型,并且size都是3x224x224,这样我们就得到处理好的单个样本对了。

目标检测篇

目标检测任务其实也是同理的,这个任务除了对物体分类还要定位物体的位置,也就是框出物体,那么除了需要物体的类别外,我们当然还需要预先框出物体,有了真实框才可以训练目标检测的模型。这里只实现使用读取VOC类型数据集的load_data_from_txt函数来制作样本对,对于COCO类型数据集的实现是一样的,因此就不一一实现了。

DetectDataset示例代码如下:

import os
import numpy as np
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class_names = [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]

def load_data_from_txt(text, img_root, anno_root, remove_difficult=False):
	# 详细内容见计算机视觉技巧合集(二)如何读取数据之目标检测篇-补充1
    return img_paths, all_labels

class DetectDataset(Dataset):
    def __init__(self, img_paths, labels, tranforms=None):
        # 初始化图像路径变量
        self.img_paths = img_paths
        # 初始化标签变量
        self.labels = labels
        # 初始化图像处理和数据增强函数变量
        self.tranforms = tranforms

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):

        # 打开图像,转成RGB格式
        img = Image.open(self.img_paths[idx]).convert('RGB')

        # 判断转换是否成功
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))

        # 进行图像预处理和数据增强
        if self.tranforms is not None:
            img = self.tranforms(img)

        # 读取标签
        label = self.labels[idx]

        return img, label


if __name__ == "__main__":

    text_path = r"G:\datasets\VOCdevkit\VOC2012\ImageSets\Main\train.txt"
    img_root = r"G:\datasets\VOCdevkit\VOC2012\JPEGImages"
    anno_root = r"G:\datasets\VOCdevkit\VOC2012\Annotations"
    img_paths, all_labels = load_data_from_txt(text_path, img_root, anno_root, remove_difficult=True)
    print(f"图像总数: {len(img_paths)}")
    print(f"标签总数: {len(all_labels)}")

    train_transforms = transforms.Compose([transforms.Resize((224, 224)),
                                           transforms.ToTensor()])

    train_dataset = DetectDataset(img_paths, all_labels, train_transforms)
    # 展示前2个样本对
    for index, data in enumerate(train_dataset):
        img, label = data
        print(f"第{index}张图像和对应的标签")
        print(img.dtype, img.shape)
        print(label, type(label))

        if index == 1:
            break

程序运行结果如下:
在这里插入图片描述
可以看出不论是ClassDataset和DetectDataset返回的图像都只是简单的变形和转换成tensor类型,这样虽然可以作为模型的输入,但由于样本图像比较简单,并且不够多样化,所以模型在训练过程中其实很容易就会过拟合,所以为了获得更加多样更加复杂的样本图像,我们在训练中还需要对其进行数据增强,比如翻转图像、旋转图像、拼接图像和在图像中加入马赛克等等方法。下一篇会细讲如何使用torchvision已有的数据增强方法以及实现更为复杂的数据增强方法。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值