图像语义分割和目标检测(中)

本文介绍了如何使用VGG19网络搭建并训练自己的全卷积语义分割模型,涵盖了数据预处理、图像切分、标准化、类别映射、数据加载以及训练过程。作者详细展示了从读取VOC2012数据集,到创建MyDataset类,再到使用DataLoader进行训练的步骤。
摘要由CSDN通过智能技术生成

上一篇介绍的是使用与训练好的语义分割网络segmentation.fcn_resnet101(),对任意输入图像进行语义分割,该模型是以101层的ResNet网络为基础,全卷积语义分割模型。下面将基于VGG19网络,搭建、训练和测试自己的图像全卷积语义分割网络。

由于资源有限,将基于2012年VOC数据集对网络进行训练,主要使用该数据集的训练集和测试集,训练集用于训练网络,验证集防止网络过拟合。每个数据集越有1000张图片,并且图像之间的尺寸不完全相同,数据集共有21类需要学习的目标类别。下面首先导入本小节需要的库和模块,程序如下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
from PIL import Image
from time import time
import os
from skimage.io import imread
import copy
import time
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms
from torchvision.models import vgg19
from torchsummary import summary

上述代码导入了torchsummary库中的summary函数,该函数可以方便查看深度学习网络的结构。

为了使用GPU计算,使用下面的程序定义一个计算设备device,程序如下(本节程序训练和测试均在GPU上完成,在不改动程序的情况下在CPU上也能完成,但训练网络时可能会花费较长的时间)。

在读取数据并对数据进行相关预处理操作之前,先查看数据集。 

针对VOC2012数据集,一共需要分割出的目标类别有21类,其中一类为背景。在标注好的图像中,每类对应的名称和颜色值如下:

classes=['background','aeroplane','bicycle','bird','boat',
         'bottle','bus','car','cat','chair','cow','diningtable',
         'dog','horse','motorbike','person','potted plant',
         'sheep','sofa','train','tv/monitor']
colormap=[[0,0,0],[128,0,0],[0,128,0],[128,128,0],[0,0,128],
          [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
          [64,128,0],[192,128,0],[64,0,128],[192,0,128],
          [64,128,128],[192,128,128],[0,64,0],[128,64,0],
          [0,192,0],[128,192,0],[0,64,128]]

数据预处理需要对每张图像进行如下几个操作:

(1)将原始图像和标记好的图像所对应的图片路径一一对应。

(2)将图像统一切分为固定的尺寸时,需要保持原始图像和其对应的标记好的图像,在切分后每个像素也是一一对应的,所以需要对原始图像和目标的标记图像从相同的位置进行切分。在切分之前还需要过滤掉尺寸小于给定切分尺寸图像。

(3)对原始图像进行数据标准化。

(4)针对标记好的图像,每张图像均是RGB图像,将RGB值对应的类重新定义,把3D的RGB图像转化为一个二维数据,并且数组中每个位置的取值对应着图像在该像素点的类别。

为了完成上述的图像预处理操作,定义下面几个图像数据预处理的辅助函数。

 def image2label(image,colormap):
        cm2lbl=np.zeros(256**3)
        for i,cm in enumerate(colormap):
            cm2lbl[(cm[0]*256+cm[1]*256+cm[2])]=i
        image=np.array(image,dtype="int64")
        ix=(image[:,:,0]*256+image[:,:,1]*256+image[:,:,2])
        image2=cm2lbl[ix]
        return image2

image2label函数可以将一张标记好的图像转化为类别标签图像。

    def rand_crop(data,label,high,width):
        im_width,im_high=data.size
        left=np.random.randint(0,im_width-width)
        top=np.random.randint(0,im_high-high)
        right=left+width
        bottom=top+high
        data=data.crop((left,top,right,bottom))
        label=label.crop((left,top,right,bottom))
        return data,label

rand_crop函数完成对原始图像数据和被标注的标签图像进行随机裁剪任务,随机裁剪后的原图像和标签的每个像素一一对应。可通过参数high和width指定图像裁剪后的宽和高。

    def img_transforms(data,label , high , width,colormap):
        data,label=rand_crop(data,label,high,width)
        data_tfs=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],
                                 [0.229,0.224,0.225])
        ])
        data=data_tfs(data)
        label=torch.from_numpy(image2label(label,colormap))
        return data,label

read_image_path函数是从给定的文件路径中定义出对应的原始图像和标记好的目标图像的存储路径列表。原始图像路径输出为data,标记好的目标图像路径输出为label。

为了将数据定义为数据加载器Data.DataLoader()函数可以接收的数据格式,在定义好上述几个辅助函数后,这需要定义个类操作,该类需要继承torch.utils.data.Dataset类,这样就可以将自己的数据定义为数据加载器操作Data.DataLoader()函数可以接收的数据格式。程序如下所示:

    class MyDataset(Data.Dataset):
        def __init__(self,data_root,high,width,imtransform,colormap):
            self.data_root=data_root
            self.high=high
            self.width=width
            self.imtransform=imtransform
            self.colormap=colormap
            data_list,label_list=read_image_path(root=data_root)
            self.data_list=self._filter(data_list)
            self.label_list=self._filter(label_list)

        def _filter(self,images):
            return [im for im in images if (Image.open(im).size[1]> high and
                                            Image.open(im).size[0]> width)]

        def __getitem__(self, idx):
            img=self.data_list[idx]
            label=self.label_list[idx]
            img=Image.open(img)
            label=Image.open(label).convert('RGB')
            img,label=self.imtransform(img,label,self.high,self.width,self.colormap)
            return img,label

        def __len__(self):
            return len(self.data_list)

在上面定义的类MyDataset包含了一个_filter方法,该方法用于过滤掉图像的尺寸小于固定切分尺寸的样本。在类中每张图像的读取通过Image.open()函数完成。下面使用MyDataset()函数读取数据集的原始数据和对应的标签数据,然后使用Data.DataLoader()函数建立数据加载器,并且每个batch中包含4张图像,程序如下所示:

 high,width=320,480
    voc_train=MyDataset("F:/程序/programs/data/VOC2012/ImageSets/Segmentation/train.txt",
                        high,width,img_transforms,colormap)
    voc_val=MyDataset("F:/程序/programs/data/VOC2012/ImageSets/Segmentation/val.txt",
                      high,width,img_transforms,colormap)
    train_loader=Data.DataLoader(voc_train,batch_size=4,shuffle=True,
                                 num_workers=0,pin_memory=True)
    val_loader=Data.DataLoader(voc_val,batch_size=4,shuffle=True,
                               num_workers=0,pin_memory=True)
    for step,(b_x,b_y) in enumerate(train_loader):
        if step>0:
            break
    # print("b_x.shape:",b_x.shape)
    # print("b_y.shape:",b_y.shape)

运行结果如下:

 从一个batch的图像尺寸输出中可以看出,训练数据中的b_x包含4张320*480的RGB图像,而B_y则包含4张320*480的类别标签数据。下面可以将一个batch的图像和其标签进行可视化,以检查数据是否预处理正确,在可视化之前需要定义两个预处理函数,即inv_normalize_image()和label2image()。

    def inv_normalize_image(data):
        rgb_mean=np.array([0.485,0.456,0.406])
        rgb_std=np.array([0.229,0.224,0.225])
        data=data.astype('float32')*rgb_std+rgb_mean
        return data.clip(0,1)
    def label2image(prelabel,colormap):
        h,w=prelabel.shape
        prelabel=prelabel.reshape(h*w,-1)
        image=np.zeros((h*w,3),dtype="int32")
        for ii in range(len(colormap)):
            index=np.where(prelabel==ii)
            image[index,:]=colormap[ii]
        return image.reshape(h,w,3)

在上面的两个函数中,inv_normalize_image函数用于将标准化后的原始图像进行你标准化操作,可方便对图像数据进行可视化;而label2image函数这是将二维的类别标签数据转化为三维的图像分割后的数据,不同的类别转化为特定的RGB值。下面针对一个batch的图像进行可视化操作,程序如下所示:

b_x_numpy=b_x.data.numpy()
    b_x_numpy=b_x_numpy.transpose(0,2,3,1)
    b_y_numpy=b_y.data.numpy()
    plt.figure(figsize=(16,6))
    for ii in range(4):
        plt.subplot(2,4,ii+1)
        plt.imshow(inv_normalize_image(b_x_numpy[ii]))
        plt.axis("off")
        plt.subplot(2,4,ii+5)
        plt.imshow(label2image(b_y_numpy[ii],colormap))
        plt.axis("off")
    plt.subplots_adjust(wspace=0.1,hspace=0.1)
    plt.show()

运行结果如下:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

mez_Blog

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值