Building an FCN-8s Semantic Segmentation Model with PyTorch

A summary of what I have learned about semantic segmentation:

I. Selecting a Dataset

For a first pass at semantic segmentation I use the VOC2012 dataset, which has 21 classes: 20 foreground object classes plus 1 background class. Download: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
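The archive can be unpacked with Python's standard tarfile module; a minimal sketch (the local archive name and target folder are assumptions, chosen to match the paths used later in this post):

import tarfile

# unpack the VOC2012 archive; it contains a VOCdevkit/VOC2012 folder tree
with tarfile.open("VOCtrainval_11-May-2012.tar") as tar:
    tar.extractall(path="D:\\DataSets\\PascalVOC2012")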

Once unpacked, a segmentation network mainly uses three folders: "ImageSets", "JPEGImages", and "SegmentationClass". ImageSets\Segmentation holds the image-name lists for the training and validation splits, JPEGImages holds the original images, and SegmentationClass holds the annotated label images. The "SegmentationObject" folder holds the label images used for instance segmentation.

1. Define the device (GPU or CPU)
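This step is a single line; it matches the device definition in the complete code at the end of the post:

import torch

# use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)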

2. Prepare the dataset

(1) Build the labels: load the annotated images and set up a list holding the class information for each label, so that every pixel value maps to a class name.

# the names of the 21 classes (index 0 is the background)
classes = ['background','aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant','sheep','sofa','train','tv/monitor']
# the RGB value of each class
colormap = [[0,0,0],[128,0,0],[0,128,0],[128,128,0],[0,0,128],[128,0,128],[0,128,128],[128,128,128],
            [64,0,0],[192,0,0],[64,128,0],[192,128,0],[64,0,128],[192,0,128],[64,128,128],[192,128,128],
            [0,64,0],[128,64,0],[0,192,0],[128,192,0],[0,64,128]]

Method: define a one-dimensional array with 256^3 elements, so that every possible pixel value of a three-channel image has a corresponding entry:

cm2lbl = np.zeros(256**3)

Classify the label colors:

for i, cm in enumerate(colormap):
    cm2lbl[(cm[0]*256 + cm[1])*256 + cm[2]] = i

Method: this guarantees that every pixel value that appears in a label has a corresponding class; for details see my earlier post:

Pytorch语义分割理解_视觉菜鸟Leonardo的博客-CSDN博客

The loop above sets every element of cm2lbl that is associated with a label color to its class index (0-20).

Next, for each pixel of an input label image, compute its position in cm2lbl to look up the class index (0-20); the class index in turn maps back to an RGB value through colormap:

image = np.array(image, dtype="int64")
image2 = cm2lbl[(image[:,:,0]*256 + image[:,:,1])*256 + image[:,:,2]]
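As a quick sanity check of the mapping, here is a minimal self-contained sketch (the 1x2 toy label image is made up for illustration, and only the first three VOC colors are used):

import numpy as np

colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]  # background, aeroplane, bicycle
cm2lbl = np.zeros(256 ** 3)
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0]*256 + cm[1])*256 + cm[2]] = i

# a toy 1x2 label image: one background pixel, one aeroplane pixel
image = np.array([[[0, 0, 0], [128, 0, 0]]], dtype="int64")
image2 = cm2lbl[(image[:, :, 0]*256 + image[:, :, 1])*256 + image[:, :, 2]]
print(image2)  # [[0. 1.]] -> class 0 (background) and class 1 (aeroplane)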

(2) The input image and its label must stay the same size, and when both are cropped to a common size, corresponding pixels must stay aligned (crop both with the same box at the same scale). To increase the amount of data, crop at a random position (random_crop):

img_width, img_high = image.size
width, high = a, b  # target size
left = np.random.randint(0, img_width - width)
top = np.random.randint(0, img_high - high)
right = left + width
bottom = top + high
# crop the image and the label with the same box
image = image.crop((left, top, right, bottom))
label = label.crop((left, top, right, bottom))

This yields randomly cropped input and label images of the same size, with the pixel-to-pixel correspondence preserved.

The two methods above live in separate functions and are both part of the image preprocessing; next, define one function that ties the preprocessing together.

For an incoming image, first crop it by calling random_crop:

data, label = random_crop(data, label, width, high)

The cropped image is then normalized with transforms.Compose():

data_trans = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406],    # ImageNet channel means
                                                      [0.229, 0.224, 0.225])])  # ImageNet channel stds

After the image is normalized, convert the label to class indices with the image2label function:

label = torch.from_numpy(image2label(label, colormap))

Note that everything up to here has been in numpy format and must be converted to torch tensors.

With the processing above, label encoding and input conversion are done, but we still need a function that loads the images.

Loading images: on the same principle as loading image files in my earlier line-structured-light work, prepare txt files in advance that list the image names of each split, then build the data and label paths from each name and read the images through those paths:

root = "xxx"  # path to the txt file listing the image names
image = np.loadtxt(root, dtype=str)
n = len(image)
data, label = [None] * n, [None] * n
for i, name in enumerate(image):
    data[i] = "xxxxxxx/%s.jpg" % name
    label[i] = "xxxxxx/%s.png" % name

3. Define the data conversion class

We need a class that calls all of the helper functions above and converts the data into a format that DataLoader() accepts.

# inherits from the Dataset class
class MyDataset(Data.Dataset):
    def __init__(self, data_root, high, width, imtransform, colormap):
        # initialization; all of the helper functions are wired in here
        self.data_root = data_root
        self.high = high
        self.width = width
        self.imtransform = imtransform
        self.colormap = colormap
        # read the dataset and label file names with the function above
        data_list, label_list = read_image_path(root=data_root)
        self.data_list = self._filter(data_list)
        self.label_list = self._filter(label_list)

    # filter out images smaller than the specified crop size
    def _filter(self, images):
        return [im for im in images if (Image.open(im).size[1] > self.high and Image.open(im).size[0] > self.width)]

    # read one sample by index
    def __getitem__(self, idx):
        img = self.data_list[idx]
        label = self.label_list[idx]
        img = Image.open(img)
        label = Image.open(label).convert('RGB')
        img, label = self.imtransform(img, label, self.high, self.width, self.colormap)
        return img, label

    def __len__(self):
        return len(self.data_list)
     

The class above completes the dataset preprocessing; next, hand the datasets to DataLoader for loading.

voc_train and voc_val are created directly from the class:

# load the data
high, width = 320, 480
voc_train = MyDataset("D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\train.txt", high, width, img_transform, colormap)
voc_val = MyDataset("D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\val.txt", high, width, img_transform, colormap)
# data loaders with 4 images per batch
train_loader = Data.DataLoader(voc_train, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)  # pinned (page-locked) memory speeds up host-to-GPU transfer
val_loader = Data.DataLoader(voc_val, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
# grab one batch from the training loader to check the sample dimensions
for step, (b_x, b_y) in enumerate(train_loader):
    if step > 0:
        break
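Inside that loop, printing the tensor shapes (the complete code at the end includes this check as comments) confirms the batch layout; with high, width = 320, 480 and batch_size=4 the expected output is:

print("b_x.shape:", b_x.shape)  # torch.Size([4, 3, 320, 480]): batch x channels x high x width
print("b_y.shape:", b_y.shape)  # torch.Size([4, 320, 480]): one class index per pixel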

With one batch in hand, we visualize it to check that preprocessing is correct. First define two helper functions:

# map the normalized data back to the 0-1 range
def inv_normalize_image(data):
    # undo the normalization so the image can be visualized in the 0-1 range;
    # mean and std are the ImageNet channel statistics used by Normalize
    rgb_mean = np.array([0.485, 0.456, 0.406])
    # mean = [123.680, 116.779, 103.939]  # when the RGB image is in the 0-255 range
    rgb_std = np.array([0.229, 0.224, 0.225])
    data = data.astype('float32') * rgb_std + rgb_mean
    return data.clip(0, 1)

# turn a predicted label map back into an image
def label2_image(prelabel, colormap):
    # label2image converts a 2-D class-label matrix into a 3-D segmented image;
    # it is the inverse of image2label: each class maps to its specific RGB value
    # works on a single label map
    h, w = prelabel.shape
    prelabel = prelabel.reshape(h * w, -1)
    image = np.zeros((h * w, 3), dtype="int32")
    for ii in range(len(colormap)):
        index = np.where(prelabel == ii)  # indices of the pixels belonging to class ii
        image[index, :] = colormap[ii]
    return image.reshape(h, w, 3)

Now visualize:

## visualize one batch of images to check that the preprocessing is correct
b_x_numpy = b_x.data.numpy()
b_x_numpy = b_x_numpy.transpose(0, 2, 3, 1)  # move the channel dimension to the end
b_y_numpy = b_y.data.numpy()
plt.figure(figsize=(16, 6))
for ii in range(4):
    plt.subplot(2, 4, ii + 1)
    plt.imshow(inv_normalize_image(b_x_numpy[ii]))
    plt.axis("off")
    plt.subplot(2, 4, ii + 5)  # ii+5 places the label image in the second row
    plt.imshow(label2_image(b_y_numpy[ii], colormap))
    plt.axis("off")
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()

[Figure: one batch after preprocessing, with the original images in the top row and their label images in the bottom row.]

II. Building the Network

FCN is a fully convolutional network: it discards the fully connected layers at the end of a conventional classification network and replaces them with convolutional layers. The FCN-8s used here is adapted from VGG19, dropping its final average-pooling and fully connected layers. First load the VGG19 network:

model_vgg19 = vgg19(pretrained=True)
base_model = model_vgg19.features
summary(base_model, (3, high, width))  # print the size of each layer

The output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 320, 480]           1,792
              ReLU-2         [-1, 64, 320, 480]               0
            Conv2d-3         [-1, 64, 320, 480]          36,928
              ReLU-4         [-1, 64, 320, 480]               0
         MaxPool2d-5         [-1, 64, 160, 240]               0
            Conv2d-6        [-1, 128, 160, 240]          73,856
              ReLU-7        [-1, 128, 160, 240]               0
            Conv2d-8        [-1, 128, 160, 240]         147,584
              ReLU-9        [-1, 128, 160, 240]               0
        MaxPool2d-10         [-1, 128, 80, 120]               0
           Conv2d-11         [-1, 256, 80, 120]         295,168
             ReLU-12         [-1, 256, 80, 120]               0
           Conv2d-13         [-1, 256, 80, 120]         590,080
             ReLU-14         [-1, 256, 80, 120]               0
           Conv2d-15         [-1, 256, 80, 120]         590,080
             ReLU-16         [-1, 256, 80, 120]               0
           Conv2d-17         [-1, 256, 80, 120]         590,080
             ReLU-18         [-1, 256, 80, 120]               0
        MaxPool2d-19          [-1, 256, 40, 60]               0
           Conv2d-20          [-1, 512, 40, 60]       1,180,160
             ReLU-21          [-1, 512, 40, 60]               0
           Conv2d-22          [-1, 512, 40, 60]       2,359,808
             ReLU-23          [-1, 512, 40, 60]               0
           Conv2d-24          [-1, 512, 40, 60]       2,359,808
             ReLU-25          [-1, 512, 40, 60]               0
           Conv2d-26          [-1, 512, 40, 60]       2,359,808
             ReLU-27          [-1, 512, 40, 60]               0
        MaxPool2d-28          [-1, 512, 20, 30]               0
           Conv2d-29          [-1, 512, 20, 30]       2,359,808
             ReLU-30          [-1, 512, 20, 30]               0
           Conv2d-31          [-1, 512, 20, 30]       2,359,808
             ReLU-32          [-1, 512, 20, 30]               0
           Conv2d-33          [-1, 512, 20, 30]       2,359,808
             ReLU-34          [-1, 512, 20, 30]               0
           Conv2d-35          [-1, 512, 20, 30]       2,359,808
             ReLU-36          [-1, 512, 20, 30]               0
        MaxPool2d-37          [-1, 512, 10, 15]               0
================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.76
Forward/backward pass size (MB): 729.49
Params size (MB): 76.39
Estimated Total Size (MB): 807.64
----------------------------------------------------------------

The code above imports the VGG19 backbone with the fully connected layers discarded; next, append the transposed-convolution layers the network needs:

class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super(FCN8s, self).__init__()
        self.num_classes = num_classes
        model_vgg19 = vgg19(pretrained=True)
        self.base_model = model_vgg19.features
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2,
                                          padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1, 1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1, 1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1, 1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1, 1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
        # the positions of the MaxPool2d layers in VGG19
        self.layers = {"4": "maxpool_1", "9": "maxpool_2", "18": "maxpool_3",
                       "27": "maxpool_4", "36": "maxpool_5"}

    def forward(self, x):
        output = {}
        for name, layer in self.base_model._modules.items():
            x = layer(x)
            if name in self.layers:
                output[self.layers[name]] = x
        x5 = output["maxpool_5"]
        x4 = output["maxpool_4"]
        x3 = output["maxpool_3"]
        # deconvolve the features step by step back to the original image size
        score = self.relu(self.deconv1(x5))
        # element-wise addition (skip connection)
        score = self.bn1(score + x4)
        score = self.relu(self.deconv2(score))
        score = self.bn2(score + x3)
        score = self.bn3(self.relu(self.deconv3(score)))
        score = self.bn4(self.relu(self.deconv4(score)))
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score
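A note on the upsampling sizes: for nn.ConvTranspose2d the output size is out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. With kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1 as above, this gives out = (in - 1) * 2 - 2 + 2 + 1 + 1 = 2 * in, so each deconv layer exactly doubles the feature map, and the five of them together undo the five stride-2 max-pooling stages (10x15 back up to 320x480).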

With the network built, check its structure:

fcn8s = FCN8s(21).to(device)  # 21 classes, on the GPU
summary(fcn8s, input_size=(3, high, width))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 320, 480]           1,792
              ReLU-2         [-1, 64, 320, 480]               0
            Conv2d-3         [-1, 64, 320, 480]          36,928
              ReLU-4         [-1, 64, 320, 480]               0
         MaxPool2d-5         [-1, 64, 160, 240]               0
            Conv2d-6        [-1, 128, 160, 240]          73,856
              ReLU-7        [-1, 128, 160, 240]               0
            Conv2d-8        [-1, 128, 160, 240]         147,584
              ReLU-9        [-1, 128, 160, 240]               0
        MaxPool2d-10         [-1, 128, 80, 120]               0
           Conv2d-11         [-1, 256, 80, 120]         295,168
             ReLU-12         [-1, 256, 80, 120]               0
           Conv2d-13         [-1, 256, 80, 120]         590,080
             ReLU-14         [-1, 256, 80, 120]               0
           Conv2d-15         [-1, 256, 80, 120]         590,080
             ReLU-16         [-1, 256, 80, 120]               0
           Conv2d-17         [-1, 256, 80, 120]         590,080
             ReLU-18         [-1, 256, 80, 120]               0
        MaxPool2d-19          [-1, 256, 40, 60]               0
           Conv2d-20          [-1, 512, 40, 60]       1,180,160
             ReLU-21          [-1, 512, 40, 60]               0
           Conv2d-22          [-1, 512, 40, 60]       2,359,808
             ReLU-23          [-1, 512, 40, 60]               0
           Conv2d-24          [-1, 512, 40, 60]       2,359,808
             ReLU-25          [-1, 512, 40, 60]               0
           Conv2d-26          [-1, 512, 40, 60]       2,359,808
             ReLU-27          [-1, 512, 40, 60]               0
        MaxPool2d-28          [-1, 512, 20, 30]               0
           Conv2d-29          [-1, 512, 20, 30]       2,359,808
             ReLU-30          [-1, 512, 20, 30]               0
           Conv2d-31          [-1, 512, 20, 30]       2,359,808
             ReLU-32          [-1, 512, 20, 30]               0
           Conv2d-33          [-1, 512, 20, 30]       2,359,808
             ReLU-34          [-1, 512, 20, 30]               0
           Conv2d-35          [-1, 512, 20, 30]       2,359,808
             ReLU-36          [-1, 512, 20, 30]               0
        MaxPool2d-37          [-1, 512, 10, 15]               0
  ConvTranspose2d-38          [-1, 512, 20, 30]       2,359,808
             ReLU-39          [-1, 512, 20, 30]               0
      BatchNorm2d-40          [-1, 512, 20, 30]           1,024
  ConvTranspose2d-41          [-1, 256, 40, 60]       1,179,904
             ReLU-42          [-1, 256, 40, 60]               0
      BatchNorm2d-43          [-1, 256, 40, 60]             512
  ConvTranspose2d-44         [-1, 128, 80, 120]         295,040
             ReLU-45         [-1, 128, 80, 120]               0
      BatchNorm2d-46         [-1, 128, 80, 120]             256
  ConvTranspose2d-47         [-1, 64, 160, 240]          73,792
             ReLU-48         [-1, 64, 160, 240]               0
      BatchNorm2d-49         [-1, 64, 160, 240]             128
  ConvTranspose2d-50         [-1, 32, 320, 480]          18,464
             ReLU-51         [-1, 32, 320, 480]               0
      BatchNorm2d-52         [-1, 32, 320, 480]              64
           Conv2d-53         [-1, 21, 320, 480]             693
================================================================
Total params: 23,954,069
Trainable params: 23,954,069
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.76
Forward/backward pass size (MB): 972.07
Params size (MB): 91.38
Estimated Total Size (MB): 1065.21
----------------------------------------------------------------

III. Training the Network

Here is the training loop, annotated line by line:

def train_model(model, criterion, optimizer, traindataloader, valdataloader, num_epochs=25):
    """
    model: network to train; criterion: loss function; optimizer: optimization method;
    traindataloader: training data loader; valdataloader: validation data loader;
    num_epochs: number of training epochs
    """
    since = time.time()  # start time
    best_model_wts = copy.deepcopy(model.state_dict())  # copy the model's parameter state (from memory)
    best_loss = 1e10  # 1x10^10
    train_loss_all = []  # running history of training loss
    train_acc_all = []  # running history of training accuracy (declared, not filled in here)
    val_loss_all = []
    val_acc_all = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        train_loss = 0.0
        train_num = 0
        val_loss = 0.0
        val_num = 0
        # each epoch has a training phase and a validation phase
        model.train()  # set the model to training mode
        for step, (b_x, b_y) in enumerate(traindataloader):
            optimizer.zero_grad()
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)
            out = model(b_x)
            out = F.log_softmax(out, dim=1)
            # dim=0 normalizes over columns, dim=1 over rows:
            # F.softmax(x, dim=1) or F.softmax(x, dim=0);
            # F.log_softmax takes a log on top of the softmax result
            pre_lab = torch.argmax(out, dim=1)  # predicted labels: per-pixel argmax over the class dimension
            loss = criterion(out, b_y)  # compute the loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(b_y)
            train_num += len(b_y)
        # loss on the training set for this epoch
        train_loss_all.append(train_loss / train_num)
        print('{} Train loss: {:.4f}'.format(epoch, train_loss_all[-1]))

        # loss on the validation set after this epoch of training
        model.eval()
        for step, (b_x, b_y) in enumerate(valdataloader):
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)
            # free unused GPU memory
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()
            out = model(b_x)
            out = F.log_softmax(out, dim=1)
            pre_lab = torch.argmax(out, dim=1)  # predicted labels
            loss = criterion(out, b_y)  # compute the loss
            val_loss += loss.item() * len(b_y)
            val_num += len(b_y)
        # loss on the validation set for this epoch
        val_loss_all.append(val_loss / val_num)
        print('{} Val Loss: {:.4f}'.format(epoch, val_loss_all[-1]))
        # keep the best network parameters
        if val_loss_all[-1] < best_loss:
            best_loss = val_loss_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())
        # time spent so far
        time_use = time.time() - since
        print("Train and val complete in {:.0f}m {:.0f}s".format(time_use // 60, time_use % 60))
    train_process = pd.DataFrame(
        data={"epoch": range(num_epochs),
              "train_loss_all": train_loss_all,
              "val_loss_all": val_loss_all})
    # load and return the best model
    model.load_state_dict(best_model_wts)
    return model, train_process

IV. Model Prediction

To train the network, define the loss function, optimizer, and learning rate. Since the forward pass ends with F.log_softmax, NLLLoss is used here; the combination is equivalent to CrossEntropyLoss.

# define the loss function and optimizer
LR = 0.0003
criterion = nn.NLLLoss()
optimizer = optim.Adam(fcn8s.parameters(), lr=LR, weight_decay=1e-4)
# train the model iteratively, passing over all the data for num_epochs rounds
fcn8s, train_process = train_model(
    fcn8s, criterion, optimizer, train_loader, val_loader, num_epochs=15)
# save the trained network
torch.save(fcn8s, "D:\\Lee's Net\\fcn8s.pkl")
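Since torch.save above pickles the whole module, a later session can restore the trained network for inference with torch.load; a minimal sketch, assuming the FCN8s class definition is importable when loading:

# reload the saved network and switch to inference mode
fcn8s = torch.load("D:\\Lee's Net\\fcn8s.pkl")
fcn8s.eval()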

The complete code:

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # allow a duplicate OpenMP runtime (common Windows workaround)
import copy
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchsummary import summary
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.utils import make_grid

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()

classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
            'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
            'dog', 'horse', 'motorbike', 'person', 'potted plant',
            'sheep', 'sofa', 'train', 'tv/monitor']

colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
            [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
            [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
            [0, 192, 0], [128, 192, 0], [0, 64, 128]]


# given an annotated image, find the object class of each pixel value
def image2label(image, colormap):
    # convert the label image so that each pixel value becomes a class index
    cm2lbl = np.zeros(256 ** 3)
    for i, cm in enumerate(colormap):
        cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
    # convert one image
    image = np.array(image, dtype="int")
    ix = ((image[:, :, 0] * 256 + image[:, :, 1]) * 256 + image[:, :, 2])
    image2 = cm2lbl[ix]
    return image2


# randomly crop an image and its label with the same box
def rand_crop(data, label, high, width):
    im_width, im_high = data.size
    # generate a random crop position
    left = np.random.randint(0, im_width - width)
    top = np.random.randint(0, im_high - high)
    right = left + width
    bottom = top + high
    data = data.crop((left, top, right, bottom))
    label = label.crop((left, top, right, bottom))
    return data, label


# the transform applied to a single image/label pair
def img_transforms(data, label, high, width, colormap):
    data, label = rand_crop(data, label, high, width)
    data_tfs = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])])
    data = data_tfs(data)
    label = torch.from_numpy(image2label(label, colormap))
    return data, label


# build the lists of image and label file paths to read
def read_image_path(root="D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\train.txt"):
    """Collect the paths of all image files listed in the given split file."""
    image = np.loadtxt(root, dtype=str)
    # print("image", image)
    n = len(image)
    data, label = [None] * n, [None] * n
    for i, fname in enumerate(image):
        data[i] = "D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\JPEGImages\\%s.jpg" % (fname)
        label[i] = "D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\SegmentationClass\\%s.png" % (fname)
    return data, label


# define MyDataset, inheriting from torch.utils.data.Dataset
class Mydataset(Data.Dataset):
    # reads the images and applies cropping and the other transforms
    def __init__(self, data_root, high, width, imtransform, colormap):
        # data_root: the split file listing image names; high, width: size after cropping
        # imtransform: the image preprocessing function; colormap: the class colors
        self.data_root = data_root
        self.high = high
        self.width = width
        self.imtransform = imtransform
        self.colormap = colormap
        data_list, label_list = read_image_path(root=data_root)
        self.data_list = self._filter(data_list)
        self.label_list = self._filter(label_list)

    def _filter(self, images):
        # filter out images smaller than the specified crop size
        return [im for im in images if (Image.open(im).size[1] > self.high and Image.open(im).size[0] > self.width)]

    def __getitem__(self, idx):
        img = self.data_list[idx]
        label = self.label_list[idx]
        img = Image.open(img)
        label = Image.open(label).convert('RGB')
        img, label = self.imtransform(img, label, self.high, self.width, self.colormap)
        return img, label

    def __len__(self):
        return len(self.data_list)


# map the normalized data back to the 0-1 range
def inv_normalize_image(data):
    rgb_mean = np.array([0.485, 0.456, 0.406])
    rgb_std = np.array([0.229, 0.224, 0.225])
    data = data.astype('float32') * rgb_std + rgb_mean
    return data.clip(0, 1)


# turn a predicted label map back into an image
def label2_image(prelabel, colormap):
    # convert one predicted label map into an RGB image
    h, w = prelabel.shape
    prelabel = prelabel.reshape(h * w, -1)
    image = np.zeros((h * w, 3), dtype="int32")
    for ii in range(len(colormap)):
        index = np.where(prelabel == ii)
        image[index, :] = colormap[ii]
    return image.reshape(h, w, 3)


class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super(FCN8s, self).__init__()
        # num_classes: the number of classes in the training data
        self.num_classes = num_classes

        # use the pretrained vgg19 network as the backbone
        model_vgg19 = vgg19(pretrained=True)
        # drop vgg19's trailing AdaptiveAvgPool2d and Linear layers
        self.base_model = model_vgg19.features
        # define the extra layers we need
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1, 1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1, 1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1, 1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1, 1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

        # the positions of the MaxPool2d layers in vgg19.features
        self.layers = {"4": "maxpool_1", "9": "maxpool_2",
                       "18": "maxpool_3", "27": "maxpool_4",
                       "36": "maxpool_5"}

    def forward(self, x):
        output = {}
        for name, layer in self.base_model._modules.items():
            # pass the input through the backbone layer by layer
            x = layer(x)
            # if this is one of the tracked layers, save its feature map to output
            if name in self.layers:
                output[self.layers[name]] = x
        x5 = output["maxpool_5"]
        x4 = output["maxpool_4"]
        x3 = output["maxpool_3"]

        # transposed convolutions gradually upsample the features back to the input size
        score = self.relu(self.deconv1(x5))
        score = self.bn1(score + x4)
        score = self.relu(self.deconv2(score))
        score = self.bn2(score + x3)
        score = self.bn3(self.relu(self.deconv3(score)))
        score = self.bn4(self.relu(self.deconv4(score)))
        score = self.bn5(self.relu(self.deconv5(score)))
        score = self.classifier(score)
        return score


def train_model(model, criterion, optimizer, traindataloader, valdataloader, num_epochs=25):
    since = time.time()
    best_models_wts = copy.deepcopy(model.state_dict())
    bestloss = 1e10
    train_loss_all = []
    train_acc_all = []
    val_acc_all = []
    val_loss_all = []
    for epoch in range(0, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        train_loss = 0
        train_num = 0
        val_loss = 0
        val_num = 0
        # each epoch has a training phase and a validation phase
        model.train()
        for step, (b_x, b_y) in enumerate(traindataloader):
            optimizer.zero_grad()
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)
            out = model(b_x)
            out = F.log_softmax(out, dim=1)
            pre_lab = torch.argmax(out, 1)  # predicted labels
            loss = criterion(out, b_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(b_y)
            train_num += len(b_y)
        # loss on the training set for this epoch
        train_loss_all.append(train_loss / train_num)
        print('{} Train loss: {:.4f}'.format(epoch, train_loss_all[-1]))

        # loss on the validation set after this epoch of training
        model.eval()
        for step, (b_x, b_y) in enumerate(valdataloader):
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)

            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()

            out = model(b_x)
            out = F.log_softmax(out, dim=1)
            pre_lab = torch.argmax(out, 1)
            loss = criterion(out, b_y)
            val_loss += loss.item() * len(b_y)
            val_num += len(b_y)

        # loss on the validation set for this epoch
        val_loss_all.append(val_loss / val_num)
        print('{} Val Loss:{:.4f}'.format(epoch, val_loss_all[-1]))
        # keep the best network parameters
        if val_loss_all[-1] < bestloss:
            bestloss = val_loss_all[-1]
            best_models_wts = copy.deepcopy(model.state_dict())
        # time spent so far
        time_use = time.time() - since
        print("Train and Val complete in {:.0f}m {:.0f}s".format(time_use // 60, time_use % 60))
    train_process = pd.DataFrame(
        data={"epoch": range(num_epochs),
              "train_loss_all": train_loss_all,
              "val_loss_all": val_loss_all})
    # load and return the best model
    model.load_state_dict(best_models_wts)
    return model, train_process


if __name__ == "__main__":
    high, width = 320, 480
    voc_train = Mydataset("D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\train.txt", high, width,
                          img_transforms, colormap)
    voc_val = Mydataset("D:\\DataSets\\PascalVOC2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\val.txt", high, width,
                        img_transforms, colormap)

    # data loaders with 2 images per batch
    train_loader = Data.DataLoader(voc_train, batch_size=2, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = Data.DataLoader(voc_val, batch_size=2, shuffle=True, num_workers=0, pin_memory=True)

    # visualize one batch of data
    # for step, (b_x, b_y) in enumerate(train_loader):
    #     if step > 0:
    #         break
    #     # print the shapes and dtypes of the training images and labels
    #     print("b_x.shape:", b_x.shape)
    #     print("b_y.shape:", b_y.shape)
    #
    #     b_x_numpy = b_x.data.numpy()
    #     b_x_numpy = b_x_numpy.transpose(0,2,3,1)
    #     b_y_numpy = b_y.data.numpy()
    #     plt.figure(figsize=(16,6))
    #
    #     for ii in range(4):
    #         plt.subplot(2,4,ii+1)
    #         plt.imshow(inv_normalize_image(b_x_numpy[ii]))
    #         plt.axis("off")
    #         plt.subplot(2,4,ii+5)
    #         plt.imshow(label2_image(b_y_numpy[ii],colormap))
    #         plt.axis("off")
    #     plt.subplots_adjust(wspace=0.1,hspace=0.1)
    #     plt.show()

    fcn8s = FCN8s(21).to(device)
    summary(fcn8s, input_size=(3, high, width))
    LR = 0.0003
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(fcn8s.parameters(), lr=LR, weight_decay=1e-4)
    # train the model, iterating over all the data for num_epochs rounds
    fcn8s, train_process = train_model(
        fcn8s, criterion, optimizer, train_loader,
        val_loader, num_epochs=1
    )
    torch.save(fcn8s, "D:\\Lee's Net\\fcn8s.pkl")
    plt.figure(figsize=(10, 6))
    plt.plot(train_process.epoch, train_process.train_loss_all,
             "ro-", label="Train loss")
    plt.plot(train_process.epoch, train_process.val_loss_all,
             "bs-", label="Val loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    plt.show()  # training and validation end here
    # fetch one batch of data from the validation set
    for step, (b_x, b_y) in enumerate(val_loader):
        if step > 0:
            break
        fcn8s.eval()
        b_x = b_x.float().to(device)
        b_y = b_y.long().to(device)
        out = fcn8s(b_x)
        out = F.log_softmax(out, dim=1)
        pre_lab = torch.argmax(out, 1)
        # visualize one batch: inputs, ground-truth labels, and predicted labels
        b_x_numpy = b_x.cpu().data.numpy()
        b_x_numpy = b_x_numpy.transpose(0, 2, 3, 1)
        b_y_numpy = b_y.cpu().data.numpy()
        pre_lab_numpy = pre_lab.cpu().data.numpy()
        plt.figure(figsize=(16, 9))
        for ii in range(4):
            plt.subplot(3, 4, ii + 1)
            plt.imshow(inv_normalize_image(b_x_numpy[ii]))
            plt.axis("off")
            plt.subplot(3, 4, ii + 5)
            plt.imshow(label2_image(b_y_numpy[ii], colormap))
            plt.axis("off")
            plt.subplot(3, 4, ii + 9)
            plt.imshow(label2_image(pre_lab_numpy[ii], colormap))
            plt.axis("off")
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.show()
