最详细的语义分割---01如何读取数据集?


数据集组成


网络训练的第一步就是读取数据,关于输入图片如何读取,如何进行预处理,将会在本篇文章中进行演示。
首先需要了解的是,语义分割中图片和标签是分别保存的。以voc数据集为例,它有20个类别,加上背景总共21个类别。其中,JPEGImages文件夹下存放的是输入图片,它们都是JPG格式。每张图片都是R,G,B三通道,其像素值在0-255之间。
在这里插入图片描述

SegmentationClass文件夹下存放的是标签,它们都是PNG格式。每张标签都是单通道的,其像素值0-N之间,其中N为分类的类别数。至于为什么单通道的图片看起来还是彩色的,这其实是通过***伪色码***显示的,本质上还是单通道。特别需要注意的是,标签中的背景,也就是图中黑色的部分,它的像素值是255。
在这里插入图片描述

光有图片和标签是不够的,我们还不知道那些图片需要训练,那些图片需要验证。所以在voc数据集中还有一个文件是用来区分那些是训练用的图片,那些是预测用的图片。
在这里插入图片描述

它们都是以txt文件储存,点击进去,会发现里面全部都是图片的名称,它们都没有后缀。这些是制作这个数据集的作者为准备好用来训练的图片。
在这里插入图片描述

datasets的搭建

在pytorch中,训练模型需要将图片和标签读入对应的类当中,这个类就叫做dataset。我们读取自己的数据集的时候只需要重写这个类就可以了,特别的,我们自定义的这个类必须继承pytorch官方定义的Dataset这个父类。下图为自定义的voc dataset,它有很多类属性,例如self.root就是存放图片的根路径。每个属性的作用都在下面批注了注释,这里就不具体介绍了。
在这里插入图片描述

在自定义数据集中,我们肯定要告诉程序,我们的图片、标签存放在哪里,它应该如何读取,所以我们自定义了一个函数,它的名字叫set_files,我们在初始化的时候就执行了它。其中函数的功能如下图所示,通过这个函数,我们会得到图片和标签的根路径,它们分别会存放在self.image_dir和self.label_dir中,我们在这个类里面可以随时调用这个类属性。之前提到过,除了图片和标签,我们还需知道那些图片进行训练,那些图片进行预测。所以我们self.files中读txt文件,当我们为训练模式的时候,读取的是train.txt,当我们是验证模式的时候读取是val.txt。
在这里插入图片描述

前面我们只是得到了图片的路径,并将它们赋值给类变量,我们并没有对它们进行读取,所以我们需要一个函数来将它们读取,具体函数见下图。它传入一个index索引,通过这个索引,我们就可以从self.files中拿出我们需要训练的图片的名称,再根据之前的到的图片和标签的根路径,将名称与它们拼接,就可以得到一个完整的图片路径。我们通过Image.open函数打卡这张图片同时将它转换为一个数组,方便我们后续对它进行处理。最后返回数组形式的图片和标签。
在这里插入图片描述

两个重要的魔法方法

上面的操作是们自定义的,但是如果要实现Dataset的功能,我们就必须要重写两个方法__getitem__和__len__。其中__getitem__必须有index参数,因为这个参数控制着代码当前读取那一张图片。通过下面的代码可以发现,我们将这个index传给上面的讲到的函数,拿到具体的一张图片,并判断是否对他进行数据增强操作,同时也会统一数据的格式,最后一定要返回处理好的图片和标签。
在这里插入图片描述

__len__方法也是必须写的,因为,程序无法知道这个数据集有多少张图片,应该迭代多少次结束。所以我们要重写__len__方法,就是要告诉程序数据集的长度,它写起来也是非常简单。
在这里插入图片描述

数据增强

细心的小伙伴肯定发现__getitem__函数里面调用了self._augmentation()函数,这个函数的作用是对图片进行数据增强,由于我们已经得到了数组形式的图片和标签,那么这里对它进行数据增强已经非常简单了,因为已经有现成的库帮我们实现了这些功能,我们只用调用就好了。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

完整代码

# Originally written by Kazuto Nakashima
# https://github.com/kazuto1011/deeplab-pytorch

import numpy as np
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
from torchvision import transforms
import cv2
import random
class VOCDataset(Dataset):
    def __init__(self, root, split='train',num_classes=21, base_size=None, augment=True,
                 crop_size=321, scale=True, flip=True, rotate=True, blur=True,):
        super(VOCDataset, self).__init__()
        self.root = root     # 存放数据集的根路径
        self.num_classes = num_classes  # 数据集的类别总数
        self.MEAN = [0.45734706, 0.43338275, 0.40058118] # 数据集的均值和方差
        self.STD = [0.23965294, 0.23532275, 0.2398498]
        self.crop_size = crop_size  #裁剪图片的大小
        self.scale = scale          #是否进行scale
        self.flip = flip            #是否进行flip
        self.rotate = rotate        #是否进行rotate
        self.blur = blur            # 是否进行blur
        self.base_size = base_size  # 基础读入图片大小
        self.augment = augment  #是否进行数据增强
        self.split = split  # 拿到训练模式
        self._set_files()   # 调用函数,拿到所有训练 验证的图片名字
        self.to_tensor = transforms.ToTensor() # 对图片进行归一化处理
        self.normalize = transforms.Normalize(self.MEAN,self.STD)

    def _set_files(self):
        self.root = os.path.join(self.root, 'VOC2012')  # VOC数据集的路径
        self.image_dir = os.path.join(self.root, 'JPEGImages') # 图片的存放路径
        self.label_dir = os.path.join(self.root, 'SegmentationClass') # 标签的存放路径
        file_list = os.path.join(self.root, "ImageSets/Segmentation", self.split + ".txt")
        # 训练或验证图片的名称txt文件
        self.files = [line.rstrip() for line in tuple(open(file_list, "r"))] # 训练或验证图片的名称 放入列表
        # 这里拿到的是对应的图片的名字  放在列表中

    def _load_data(self, index):
        image_id = self.files[index]  # 根据索引取图片
        image_path = os.path.join(self.image_dir, image_id + '.jpg') # 图片路径
        label_path = os.path.join(self.label_dir, image_id + '.png') # 标签路径
        # 将图片转成数组
        image = np.asarray(Image.open(image_path), dtype=np.float32)
        label = np.asarray(Image.open(label_path), dtype=np.int32)
        return image, label

    def __getitem__(self, index):
        "__getitem__方法在自定义数据集的时候必须重写.index是输入图片的索引值"
        "在这个函数里面可以对图片进行预处理,但是要返回处理好的图片"
        image, label = self._load_data(index)   # 拿到每一张图片和标签
        if self.augment: # 判断是否进行数据增强
            image, label = self._augmentation(image, label)
        # 统一输入图片格式
        label = torch.from_numpy(np.array(label, dtype=np.float32)).long()
        image = Image.fromarray(np.uint8(image))
        return self.normalize(self.to_tensor(image)), label # 归一化 将图片转换为tensor对象

    def __len__(self):
        "__len__方法在自定义数据集时候必须重写.返回数据集的长度"
        return len(self.files)

    #数据增强函数
    def _augmentation(self, image, label):
        h, w, _ = image.shape
        if self.base_size:
            if self.scale:
                longside = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
            else:
                longside = self.base_size
            h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (
            int(1.0 * longside * h / w + 0.5), longside)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
            label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)

        h, w, _ = image.shape
        # 旋转图片在(-10°和10°之间)
        if self.rotate:
            angle = random.randint(-10, 10)
            center = (w / 2, h / 2)
            rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
            image = cv2.warpAffine(image, rot_matrix, (w, h),
                                   flags=cv2.INTER_LINEAR)
            label = cv2.warpAffine(label, rot_matrix, (w, h),
                                   flags=cv2.INTER_NEAREST)

        # 对不符合指定大小的图片进行裁剪
        if self.crop_size:
            pad_h = max(self.crop_size - h, 0)
            pad_w = max(self.crop_size - w, 0)
            pad_kwargs = {
                "top": 0,
                "bottom": pad_h,
                "left": 0,
                "right": pad_w,
                "borderType": cv2.BORDER_CONSTANT, }
            if pad_h > 0 or pad_w > 0:
                image = cv2.copyMakeBorder(image, value=0, **pad_kwargs)
                label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)

            # 对不符合大小的图片进行padding
            h, w, _ = image.shape
            start_h = random.randint(0, h - self.crop_size)
            start_w = random.randint(0, w - self.crop_size)
            end_h = start_h + self.crop_size
            end_w = start_w + self.crop_size
            image = image[start_h:end_h, start_w:end_w]
            label = label[start_h:end_h, start_w:end_w]

        # 随机反转
        if self.flip:
            if random.random() > 0.5:
                image = np.fliplr(image).copy()
                label = np.fliplr(label).copy()

        # 给图片增加高斯噪音
        if self.blur:
            sigma = random.random()
            ksize = int(3.3 * sigma)
            ksize = ksize + 1 if ksize % 2 == 0 else ksize
            image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma,
                                     borderType=cv2.BORDER_REFLECT_101)
        return image, label





在这里插入图片描述

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值