Semantic Segmentation with FCN (a PyTorch Implementation)

data_prepare_funtion.py

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.io import imread
import torch
import torch.utils.data as Data
from torchvision import transforms

def read_image(root="./data/VOC2012/ImageSets/Segmentation/train.txt"):
    """Read the images listed in the file at the given path"""
    image = np.loadtxt(root, dtype=str)
    n = len(image)
    data, label = [None] * n, [None] * n
    for i, fname in enumerate(image):
        data[i] = imread("./data/VOC2012/JPEGImages/%s.jpg" % (fname))
        label[i] = imread("./data/VOC2012/SegmentationClass/%s.png" % (fname))
    return data, label

## Given a label image, map each pixel's RGB value to its class index
def image2label(image, colormap):
    ## Build a lookup table from packed RGB value to class index
    cm2lbl = np.zeros(256**3)
    for i, cm in enumerate(colormap):
        cm2lbl[((cm[0] * 256 + cm[1]) * 256 + cm[2])] = i
    ## Convert one image
    image = np.array(image, dtype="int64")
    ix = ((image[:, :, 0] * 256 + image[:, :, 1]) * 256 + image[:, :, 2])
    # Multiplying by a scalar never changes an array's shape; image[:, :, 0]
    # extracts channel 0 from the last axis, giving an (H, W) array, so ix is
    # an (H, W) array of packed RGB values
    image2 = cm2lbl[ix]  # every element of ix is used as an index into cm2lbl
    return image2
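
# A minimal illustrative sketch of how the lookup works: each RGB triple is
# packed into one integer, (r*256 + g)*256 + b, which then indexes cm2lbl, so
# the whole image is converted in a single vectorized lookup. The helper and
# its two-color colormap are hypothetical, chosen only for the demo.
def _demo_image2label():
    demo_colormap = [[0, 0, 0], [128, 0, 0]]  # background, aeroplane (VOC colors)
    img = np.zeros((2, 2, 3), dtype="uint8")
    img[0, 0] = [128, 0, 0]                   # one "aeroplane" pixel
    print(image2label(img, demo_colormap))    # [[1. 0.] [0. 0.]]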

def center_crop(data, label, height, width):
    """data and label are both PIL.Image objects"""
    ## Use a center crop (data and label have the same size, so the same crop aligns them)
    data = transforms.CenterCrop((height, width))(data)
    label = transforms.CenterCrop((height, width))(label)
    return data, label

## Randomly crop the image data
def rand_crop(data, label, high, width):
    im_width, im_high = data.size
    ## Pick a random top-left corner for the crop
    left = np.random.randint(0, im_width - width)
    top = np.random.randint(0, im_high - high)
    right = left + width
    bottom = top + high
    data = data.crop((left, top, right, bottom))
    label = label.crop((left, top, right, bottom))
    return data, label

## Transformations applied to a single image
def img_transforms(data, label, high, width, colormap):
    data, label = rand_crop(data, label, high, width)
    data_tfs = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])])
    data = data_tfs(data)
    label = torch.from_numpy(image2label(label, colormap))
    return data, label

## Define a function that lists the data paths to read
def read_image_path(root="./data/VOC2012/ImageSets/Segmentation/train.txt"):
    """Collect the paths of all image files listed in the file at the given path"""
    image = np.loadtxt(root, dtype=str)
    n = len(image)
    data, label = [None] * n, [None] * n
    for i, fname in enumerate(image):
        data[i] = "./data/VOC2012/JPEGImages/%s.jpg" % (fname)
        label[i] = "./data/VOC2012/SegmentationClass/%s.png" % (fname)
    return data, label

def batch_visualization(b_x, b_y, colormap):
    ## Print the shapes and dtypes of the training images and their labels
    print("b_x.shape:", b_x.shape)
    print("b_y.shape:", b_y.shape)
    print("b_x.dtype:", b_x.dtype)
    print("b_y.dtype:", b_y.dtype)

    ## Visualize one batch of images to check that preprocessing is correct
    b_x_numpy = b_x.data.numpy()
    b_x_numpy = b_x_numpy.transpose(0, 2, 3, 1)
    b_y_numpy = b_y.data.numpy()
    plt.figure(figsize=(16, 6))
    for ii in range(4):
        plt.subplot(2, 4, ii + 1)
        plt.imshow(inv_normalize_image(b_x_numpy[ii]))
        plt.axis("off")
        plt.subplot(2, 4, ii + 5)
        plt.imshow(label2image(b_y_numpy[ii], colormap))
        plt.axis("off")
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

def inv_normalize_image(data):
    ## Undo the ImageNet normalization so the image can be displayed
    rgb_mean = np.array([0.485, 0.456, 0.406])
    rgb_std = np.array([0.229, 0.224, 0.225])
    data = data.astype('float32') * rgb_std + rgb_mean
    return data.clip(0, 1)

## Convert a predicted label map back to an RGB image
def label2image(prelabel, colormap):
    ## Works on a single label map
    h, w = prelabel.shape
    prelabel = prelabel.reshape(h * w)
    image = np.zeros((h * w, 3), dtype="int32")
    for ii in range(len(colormap)):
        index = np.where(prelabel == ii)  # indices of the pixels of class ii
        image[index, :] = colormap[ii]
    return image.reshape(h, w, 3)
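
# A minimal round-trip sketch (the helper and its two-color colormap are
# hypothetical): label2image is the inverse of image2label, so converting back
# and forth should recover the original label map exactly.
def _demo_label_roundtrip():
    demo_colormap = [[0, 0, 0], [128, 0, 0]]
    lab = np.array([[1, 0], [0, 1]])
    rgb = label2image(lab, demo_colormap)      # (2, 2, 3) RGB image
    assert (image2label(rgb, demo_colormap) == lab).all()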
    
## Finally, define MyDataset, a subclass of torch.utils.data.Dataset, as our custom dataset
class MyDataset(Data.Dataset):
    """Reads the images and applies cropping and the other transforms"""
    def __init__(self, data_root, high, width, imtransform, colormap):
        ## data_root: file listing the data; high, width: crop size;
        ## imtransform: preprocessing operations; colormap: class colors
        self.data_root = data_root
        self.high = high
        self.width = width
        self.imtransform = imtransform
        self.colormap = colormap
        data_list, label_list = read_image_path(root=data_root)
        self.data_list = self._filter(data_list)
        self.label_list = self._filter(label_list)

    def _filter(self, images):  # drop images smaller than the requested high/width
        return [im for im in images if (Image.open(im).size[1] > self.high and
                                        Image.open(im).size[0] > self.width)]

    def __getitem__(self, idx):
        img = self.data_list[idx]
        label = self.label_list[idx]
        img = Image.open(img)
        label = Image.open(label).convert('RGB')
        img, label = self.imtransform(img, label, self.high,
                                      self.width, self.colormap)
        return img, label

    def __len__(self):
        return len(self.data_list)

EarlyStopping.py

import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=True, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: True
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # Saves the whole model object; the target directory must already exist
        torch.save(model, "./data/trained/fcn8s_trained.pkl")
        self.val_loss_min = val_loss
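
# A minimal usage sketch with made-up losses (assumes ./data/trained/ exists,
# since save_checkpoint writes there); the dummy model is purely illustrative:
if __name__ == "__main__":
    from torch import nn
    stopper = EarlyStopping(patience=2, verbose=False)
    dummy = nn.Linear(1, 1)
    for loss in [1.0, 0.8, 0.9, 0.95, 0.99]:
        if stopper(loss, dummy):
            print("stopped early")  # triggers after 2 epochs without improvement
            break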

 

my_fcn_net.py

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import vgg19
import hiddenlayer as hl
from EarlyStopping import EarlyStopping

## The features network of vgg19 shrinks the image 32x through 5 MaxPool layers
## The size is reduced at: MaxPool2d-5 (2x), MaxPool2d-10 (4x), MaxPool2d-19 (8x),
## MaxPool2d-28 (16x), MaxPool2d-37 (32x)
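
# A quick sketch to verify the downsampling factors listed above; the helper is
# illustrative and assumes torchvision's standard vgg19 layer ordering:
def _show_vgg19_downsampling():
    feats = vgg19(pretrained=True).features
    x = torch.zeros(1, 3, 320, 480)
    for name, layer in feats._modules.items():
        x = layer(x)
        if isinstance(layer, nn.MaxPool2d):
            print(name, tuple(x.shape))  # e.g. '4' (1, 64, 160, 240)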

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
early_stopping = EarlyStopping(patience=15, verbose=True)

## Define the FCN semantic segmentation network
class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # num_classes: number of classes in the training data
        self.num_classes = num_classes
        model_vgg19 = vgg19(pretrained=True)
        self.base_model = model_vgg19.features

        ## Define the layers we need; transposed convolutions upsample the feature maps
        self.relu    = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2,
                                          padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1, 1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1, 1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1, 1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1, 1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
    
        ## Indices of the MaxPool2d layers inside vgg19.features
        self.layers = {"4" : "maxpool_1",
                       "9" : "maxpool_2",
                       "18": "maxpool_3",
                       "27": "maxpool_4",
                       "36": "maxpool_5"}

    def forward(self, x):
        output = {}
        for name, layer in self.base_model._modules.items():
            ## Pass the image through the backbone layer by layer
            x = layer(x)

            ## If this layer is one listed in self.layers, save its output
            if name in self.layers:
                output[self.layers[name]] = x

        x5 = output["maxpool_5"]  # size=(N, 512, x.H/32, x.W/32)
        x4 = output["maxpool_4"]  # size=(N, 512, x.H/16, x.W/16)
        x3 = output["maxpool_3"]  # size=(N, 256, x.H/8,  x.W/8)

        ## Apply the transposed convolutions, gradually upsampling back to the input size
        score = self.relu(self.deconv1(x5))  # size=(N, 512, x.H/16, x.W/16)
        score = self.bn1(score + x4)  # element-wise skip connection, size=(N, 512, x.H/16, x.W/16)
        score = self.relu(self.deconv2(score))  # size=(N, 256, x.H/8, x.W/8)
        score = self.bn2(score + x3)  # element-wise skip connection, size=(N, 256, x.H/8, x.W/8)
        score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
        score = self.classifier(score)  # final 1x1 convolution maps to num_classes channels
        return score  # size=(N, num_classes, x.H, x.W)
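
# Each ConvTranspose2d above (kernel 3, stride 2, padding 1, output_padding 1)
# exactly doubles the spatial size:
#   out = (in - 1)*stride - 2*padding + kernel + output_padding
#       = (in - 1)*2 - 2 + 3 + 1 = 2*in
# A minimal shape-check sketch (illustrative helper, assumes the input dims are
# multiples of 32 as required below in train.py):
def _check_fcn8s_shapes():
    net = FCN8s(num_classes=21)
    x = torch.zeros(2, 3, 320, 480)
    print(net(x).shape)  # expected: torch.Size([2, 21, 320, 480])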


def train_model(model, criterion, optimizer, traindataloader,
                valdataloader, num_epochs=25):
    """
    model: the network; criterion: loss function; optimizer: optimization method;
    traindataloader: training data loader; valdataloader: validation data loader;
    num_epochs: number of training epochs
    """
    history1 = hl.History()
    # Use hiddenlayer's Canvas for live visualization
    canvas1 = hl.Canvas()

    for epoch in range(num_epochs):

        train_loss = 0.0
        train_num = 0
        val_loss = 0.0
        val_num = 0

        # Each epoch has a training phase and a validation phase
        model.train()  ## set the model to training mode
        for step, (b_x, b_y) in enumerate(traindataloader):
            optimizer.zero_grad()
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)
            out = model(b_x)
            out = F.log_softmax(out, dim=1)
            pre_lab = torch.argmax(out, 1)  # predicted labels
            loss = criterion(out, b_y)  # compute the loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(b_y)
            train_num += len(b_y)
            
        ## Compute the epoch's loss on the training set
        train_loss_all = train_loss / train_num
        
        # Compute the loss on the validation set after this epoch's training
        model.eval()  ## set the model to evaluation mode
        with torch.no_grad():
            for step, (b_x, b_y) in enumerate(valdataloader):
                b_x = b_x.float().to(device)
                b_y = b_y.long().to(device)
                out = model(b_x)
                out = F.log_softmax(out, dim=1)
                pre_lab = torch.argmax(out, 1)
                loss = criterion(out, b_y)
                val_loss += loss.item() * len(b_y)
                val_num += len(b_y)
        ## Compute the epoch's loss on the validation set
        val_loss_all = val_loss / val_num

        print("val loss:", val_loss_all)

        history1.log(epoch, train_loss=train_loss_all,
                            val_loss=val_loss_all)
        # Visualize the training progress
        with canvas1:
            canvas1.draw_plot(history1["train_loss"])
            canvas1.draw_plot(history1["val_loss"])

        if early_stopping(val_loss_all, model):
            break

train.py

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import glob
import os
from skimage.io import imread
import copy
import time

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms
from torchvision.models import vgg19
from data_prepare_funtion import *
from my_fcn_net import *

## Define the computing device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# ## Read the training and validation data
# traindata,trainlabel = read_image(root = "./data/VOC2012/ImageSets/Segmentation/train.txt")
# valdata,vallabel = read_image(root = "./data/VOC2012/ImageSets/Segmentation/val.txt")
# print ('len_train: ', len(traindata))
# print ('len_val: ', len(valdata))

## The classes in the dataset, including the background
classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
# RGB value for each class
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]
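
## A quick sanity-check sketch: classes and colormap must stay index-aligned,
## since image2label/label2image map class i <-> colormap[i]
assert len(classes) == len(colormap) == 21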

## Read the data
high, width = 320, 480
voc_train = MyDataset("./data/VOC2012/ImageSets/Segmentation/train.txt",
                      high, width, img_transforms, colormap)
voc_val = MyDataset("./data/VOC2012/ImageSets/Segmentation/val.txt",
                    high, width, img_transforms, colormap)

# Create the data loaders; each batch holds 4 images
train_loader = Data.DataLoader(voc_train, batch_size=4, shuffle=True,
                               num_workers=16, pin_memory=True)
val_loader = Data.DataLoader(voc_val, batch_size=4, shuffle=True,
                             num_workers=16, pin_memory=True)

# Check that one batch from the training set has the expected dimensions
for step, (b_x, b_y) in enumerate(train_loader):
    if step > 0:
        break
    batch_visualization(b_x, b_y, colormap)

## Note: the input image size must be a multiple of 32
class_number = 21
fcn8s = FCN8s(class_number).to(device)
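
## A quick check (a sketch): the crop size chosen above already satisfies this,
## since 320 = 32 * 10 and 480 = 32 * 15
assert high % 32 == 0 and width % 32 == 0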

## Network training and prediction

## Define the loss function and optimizer
LR = 0.0003
criterion = nn.NLLLoss()
optimizer = optim.Adam(fcn8s.parameters(), lr=LR, weight_decay=1e-4)
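
## Aside: the model outputs raw scores, train_model applies F.log_softmax, and
## nn.NLLLoss consumes the log-probabilities; this combination is equivalent to
## using nn.CrossEntropyLoss on the raw scores directly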

## Train the model iteratively over all the data for num_epochs epochs
train_model(fcn8s, criterion, optimizer,
            train_loader, val_loader,
            num_epochs=1000)

## Predict one batch from the validation set and visualize the results
fcn8s = torch.load("./data/trained/fcn8s_trained.pkl")
fcn8s.to(device)

for step, (b_x, b_y) in enumerate(val_loader):  
    if step > 0:
        break

fcn8s.eval()
b_x = b_x.float().to(device)
b_y = b_y.long().to(device)
with torch.no_grad():
    out = fcn8s(b_x)
    out = F.log_softmax(out, dim=1)
    pre_lab = torch.argmax(out, 1)
## Visualize the batch: input images, ground-truth labels, and predictions
b_x_numpy = b_x.cpu().data.numpy()
b_x_numpy = b_x_numpy.transpose(0, 2, 3, 1)
b_y_numpy = b_y.cpu().data.numpy()
pre_lab_numpy = pre_lab.cpu().data.numpy()
plt.figure(figsize=(16,9))
for ii in range(4):
    plt.subplot(3,4,ii+1)
    plt.imshow(inv_normalize_image(b_x_numpy[ii]))
    plt.axis("off")
    plt.subplot(3,4,ii+5)
    plt.imshow(label2image(b_y_numpy[ii],colormap))
    plt.axis("off")
    plt.subplot(3,4,ii+9)
    plt.imshow(label2image(pre_lab_numpy[ii],colormap))
    plt.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()

 
