Semantic Segmentation Example: FCN for Pavement Crack-Sealing Regions (3) Training and Testing

This post walks through training an FCN model in PyTorch: parameter settings, data preprocessing, the model, and the training/validation loop. xj0paremeters.py holds the training parameters, xj1ModelFCN.py defines the model, xj2ImageDataset.py handles the data, xj3train.py runs the training, and xj4Predict.py shows the prediction step used for testing.

My training code consists of four files: (1) xj0paremeters.py stores the parameters; (2) xj1ModelFCN.py is the FCN model; (3) xj2ImageDataset.py handles the data processing; (4) xj3train.py runs the training.

1. Parameters

xj0paremeters.py

import torch.cuda

# Training data directories and train/validation split ratio
dirTrainImage = 'E:/py/dataCrack/00AA1originalNote'
dirTrainLabel = 'E:/py/dataCrack/00CC32label_2gf'
ratioTrainVal = 0.9

#region Training parameters
# Image preprocessing: resize scale
scale = 2
resizeRow = 512 * scale
resizeCol = 256 * scale

# Training hyperparameters
device = ('cuda' if torch.cuda.is_available() else 'cpu')
batchSize = 1
numWorkers = 1
pEpochs = 501
pLearningRate = 1e-3
pMomentum = 0.7
#endregion

# Test image and result directories
dirPreImage = 'E:/py/dataCrack/test'
dirPreResult = 'E:/py/dataCrack/testRes'

2. Model

xj1ModelFCN.py
See the previous post in this series: 语义分割示例—FCN识别路面灌缝区域(2)FCN.
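
For convenience, here is a minimal sketch of the interface that xj3train.py expects from xj1ModelFCN.py: a VGG_Net backbone that returns the five pooled VGG16 feature maps, and an FCNs decoder that upsamples them back to the input resolution with n_class output channels. This is my own simplified reconstruction (assuming a standard torchvision VGG16 and stride-2 transposed convolutions), not the actual code from that post.

import torch.nn as nn
from torchvision import models


class VGG_Net(nn.Module):
    def __init__(self, requires_grad=True, show_params=False):
        super().__init__()
        # on newer torchvision, use models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        features = models.vgg16(pretrained=True).features
        # the five VGG16 stages, each ending with a max-pooling layer
        self.stages = nn.ModuleList([features[0:5], features[5:10], features[10:17],
                                     features[17:24], features[24:31]])
        for p in self.parameters():
            p.requires_grad = requires_grad
        if show_params:
            for name, p in self.named_parameters():
                print(name, p.size())

    def forward(self, x):
        output = {}
        for i, stage in enumerate(self.stages, start=1):
            x = stage(x)
            output['x%d' % i] = x    # x1 at 1/2 resolution ... x5 at 1/32
        return output


class FCNs(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        # each transposed convolution doubles the spatial resolution
        self.deconv1 = nn.ConvTranspose2d(512, 512, 3, 2, 1, 1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        f = self.pretrained_net(x)
        # decoder with skip connections: add shallower encoder features while upsampling
        score = self.bn1(self.relu(self.deconv1(f['x5']))) + f['x4']
        score = self.bn2(self.relu(self.deconv2(score))) + f['x3']
        score = self.bn3(self.relu(self.deconv3(score))) + f['x2']
        score = self.bn4(self.relu(self.deconv4(score))) + f['x1']
        score = self.bn5(self.relu(self.deconv5(score)))
        return self.classifier(score)    # (N, n_class, H, W), raw scores, no activation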

3. Data Processing

The original source code and detailed comments come from the post pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码) (FCN semantic segmentation of a handbag dataset in PyTorch, training plus single-image prediction).
xj2ImageDataset.py

# coding=utf-8
import os,cv2
import numpy as np
import torch
import torchvision
import warnings
warnings.filterwarnings("ignore")

from xj0paremeters import dirTrainImage,dirTrainLabel,ratioTrainVal
from xj0paremeters import scale,resizeRow,resizeCol
from xj0paremeters import batchSize,numWorkers


# torchvision transform: convert to tensor, then normalize with the ImageNet mean/std
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485,0.456,0.406],
                                     std =[0.229,0.224,0.225])
])

# Convert a grayscale label map to a one-hot encoding; n is the number of classes
def onehot(label, n):
    buf = np.zeros(label.shape + (n,))
    n_msk = np.arange(label.size) * n + label.ravel()
    buf.ravel()[n_msk] = 1   # channel k is 1 exactly where label == k
    return buf

# Dataset: read an image/label pair, resize both, one-hot encode the label
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.transform = transform
        self.imageNames = sorted(os.listdir(dirTrainImage))  # cache the file list once

    def __len__(self):
        return len(self.imageNames)

    def __getitem__(self, item):
        imageName = self.imageNames[item]
        imageA = cv2.imread(dirTrainImage + '/' + imageName)
        if scale != 0:
            imageA = cv2.resize(imageA, (resizeCol, resizeRow))

        baseName, extension = os.path.splitext(imageName)

        # the label is a grayscale PNG with the same base name
        labelPath = dirTrainLabel + '/' + baseName + '.png'
        imageB = cv2.imread(labelPath, 0)
        if scale != 0:
            # nearest-neighbour keeps the label values binary after resizing
            imageB = cv2.resize(imageB, (resizeCol, resizeRow), interpolation=cv2.INTER_NEAREST)

        imageB = imageB / 255               # map {0, 255} to {0, 1}
        imageB = imageB.astype('uint8')
        imageB = onehot(imageB, 2)          # (H, W, 2)
        imageB = imageB.transpose(2, 0, 1)  # (2, H, W)
        imageB = torch.FloatTensor(imageB)

        if self.transform:
            imageA = self.transform(imageA)

        return imageA, imageB

# Build the dataset, split it into train/validation sets and wrap both in DataLoaders
pouring = ImageDataset(transform)
trainSize = int(ratioTrainVal*len(pouring))
valSize = len(pouring) - trainSize
trainDataset, valDataset = torch.utils.data.random_split(pouring,[trainSize,valSize])
trainLoader = torch.utils.data.DataLoader(trainDataset,batch_size=batchSize,shuffle=True,num_workers=numWorkers)
valLoader = torch.utils.data.DataLoader(valDataset,batch_size=batchSize,shuffle=True,num_workers=numWorkers)
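
As a quick sanity check of the one-hot encoding (my own addition, not part of the training pipeline), the snippet below encodes a tiny 2x2 label map. Importing from xj2ImageDataset assumes the data directories in xj0paremeters.py exist, because the loaders are built at import time.

import numpy as np
from xj2ImageDataset import onehot

label = np.array([[0, 1],
                  [1, 0]], dtype='uint8')
encoded = onehot(label, 2)        # shape (2, 2, 2): H x W x n_class
print(encoded[:, :, 0])           # 1 where the label is 0 (background)
print(encoded[:, :, 1])           # 1 where the label is 1 (crack sealing)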

4. Training

xj3train.py

import os
import time
import numpy as np
import torch
import visdom
import warnings
warnings.filterwarnings('ignore')

from xj0paremeters import pEpochs,pLearningRate,pMomentum
from xj0paremeters import device
from xj1ModelFCN import FCNs,VGG_Net
from xj2ImageDataset import trainLoader, valLoader


# Best losses seen so far; kept at module level so they persist across epochs
minTrainLoss = float('inf')
minValLoss = float('inf')


def train(epoch):
    global minTrainLoss
    timeStart = time.time()
    trainLoss = 0
    fcn_model.train()
    for index, (gf, gf_label) in enumerate(trainLoader):
        gf = gf.to(device)
        gf_label = gf_label.to(device)
        # gf.shape is torch.Size([batchSize, 3, resizeRow, resizeCol])
        # gf_label.shape is torch.Size([batchSize, 2, resizeRow, resizeCol])

        optimizer.zero_grad()
        output = fcn_model(gf)
        output = torch.sigmoid(output)  # same shape as gf_label

        loss = criterion(output, gf_label)
        loss.backward()
        iterLoss = loss.item()
        trainLoss += iterLoss
        optimizer.step()

        if index % 100 == 0:
            print('[{}/{}]\tloss={:.4f}'.format(index, len(trainLoader), iterLoss))

    trainLoss /= len(trainLoader)
    timeEnd = time.time()
    print('train loss=%.4f, time=%.2f' % (trainLoss, (timeEnd - timeStart)))
    visdomTL.line([trainLoss], [epoch], win=visdomTLwindow, update='append')

    # save whenever the average training loss improves
    if trainLoss < minTrainLoss:
        minTrainLoss = trainLoss
        torch.save(fcn_model, 'model/minTrainLoss_fcn_model.pth')

    # additionally checkpoint every second epoch once the loss is low enough
    if (np.mod(epoch, 2) == 0) and (trainLoss < 0.05):
        torch.save(fcn_model, 'model/fcn_model_{}.pth'.format(epoch))
        torch.save(fcn_model, 'model/fcn_model_{}.pt'.format(epoch))
        print('saved the model to model/fcn_model_{}.pth'.format(epoch))

    return trainLoss



def validate(epoch):
    global minValLoss
    timeStart = time.time()
    valLoss = 0
    fcn_model.eval()
    with torch.no_grad():
        for gf, gf_label in valLoader:
            gf = gf.to(device)
            gf_label = gf_label.to(device)
            output = fcn_model(gf)
            output = torch.sigmoid(output)
            loss = criterion(output, gf_label)
            valLoss += loss.item()

    valLoss /= len(valLoader)
    timeEnd = time.time()
    print('val loss=%.4f, time=%.2f' % (valLoss, (timeEnd - timeStart)))
    visdomVL.line([valLoss], [epoch], win=visdomVLwindow, update='append')

    # save whenever the average validation loss improves
    if valLoss < minValLoss:
        minValLoss = valLoss
        torch.save(fcn_model, 'model/minValLoss_fcn_model.pth')

    return valLoss




if __name__ == "__main__":
    # start the visdom server first: python -m visdom.server
    visdomTL = visdom.Visdom()
    visdomTLwindow = visdomTL.line([0], [0], opts=dict(title='train_loss'))
    visdomVL = visdom.Visdom()
    visdomVLwindow = visdomVL.line([0], [0], opts=dict(title='validate_loss'))
    visdomTVL = visdom.Visdom(env='FCN')

    # model, loss and optimizer
    os.makedirs('model', exist_ok=True)  # directory for the saved checkpoints
    vgg_model = VGG_Net(requires_grad=True, show_params=False)
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2).to(device)
    criterion = torch.nn.BCELoss().to(device)
    optimizer = torch.optim.SGD(fcn_model.parameters(), lr=pLearningRate, momentum=pMomentum)

    for epoch in range(pEpochs):
        print('epoch=', epoch, '----------------------')
        trainLoss = train(epoch)
        validateLoss = validate(epoch)
        # one new point per trace: Y has shape (1, 2), X has shape (1,)
        visdomTVL.line(X=[epoch], Y=[[trainLoss, validateLoss]],
                       win='TVL', update='append',
                       opts=dict(showlegend=True,
                                 markers=False,
                                 title='FCN train validate loss',
                                 xlabel='epoch', ylabel='loss',
                                 legend=["train loss", "validate loss"]))

(Figure: visdom curves of the training and validation loss.)
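
A side note on the loss: the script applies torch.sigmoid before BCELoss. An equivalent and numerically more stable option is BCEWithLogitsLoss, which fuses the two steps; the small self-contained check below (toy tensors standing in for a real batch, my own addition) shows that the two formulations agree.

import torch

# toy tensors with the same layout as one batch: (batchSize, n_class, H, W)
logits = torch.randn(1, 2, 8, 8)
target = torch.randint(0, 2, (1, 2, 8, 8)).float()

fused = torch.nn.BCEWithLogitsLoss()(logits, target)
split = torch.nn.BCELoss()(torch.sigmoid(logits), target)
print(fused.item(), split.item())   # the two losses match up to floating-point error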

5. Testing

Now that training is finished, let's open the blind box and run a test.
xj4Predict.py

import numpy as np
import cv2, time, os
import torch, torchvision

from xj0paremeters import device
from xj0paremeters import scale,resizeRow,resizeCol
from xj0paremeters import dirPreImage,dirPreResult


# torchvision transform: convert to tensor, then normalize with the ImageNet mean/std
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485,0.456,0.406],
                                     std =[0.229,0.224,0.225])
])


def PredictImage(path):
    imageA = cv2.imread(path)
    imageWidth = imageA.shape[1]
    imageHeight = imageA.shape[0]

    if scale != 0:
        imageA = cv2.resize(imageA, (resizeCol, resizeRow))

    imageA = transform(imageA).to(device)
    imageA = imageA.unsqueeze(0)              # add the batch dimension: (1, 3, H, W)

    with torch.no_grad():                     # no gradients needed at inference time
        output = model(imageA)
    output = torch.sigmoid(output).squeeze().cpu().numpy().round()  # (2, H, W)
    output = np.argmax(output, axis=0)        # channel 1 is the sealing class, so argmax gives a 0/1 mask

    # resize the mask back to the original resolution and save it as a PNG
    pre = cv2.resize((output * 255).astype('uint8'), (imageWidth, imageHeight),
                     interpolation=cv2.INTER_NEAREST)
    baseName, _ = os.path.splitext(os.path.basename(path))
    resultPath = os.path.join(dirPreResult, baseName + '.png')
    cv2.imwrite(resultPath, pre)
    print(resultPath)


if __name__ == '__main__':
    # load the checkpoint saved during training; map_location lets this run on a CPU-only machine
    model = torch.load('E:/py/FCN/model/minTrainLoss_fcn_model.pth', map_location=device)
    # model = torch.load('E:/py/FCN/model/minValLoss_fcn_model.pth', map_location=device)
    model = model.to(device)
    model.eval()

    if not os.path.isdir(dirPreResult):
        os.makedirs(dirPreResult)

    timeStart = time.time()

    for fileName in os.listdir(dirPreImage):
        PredictImage(os.path.join(dirPreImage, fileName))

    timeEnd = time.time()
    print('total prediction time = %.2f s' % (timeEnd - timeStart))

(Figures: a test image and the predicted crack-sealing mask.)
The results are not great…
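
To put a number on "not great" (my own addition, not from the original post), a per-image intersection-over-union between a predicted mask and its ground-truth label is a quick way to quantify the output; both are read as 0/255 grayscale PNGs. The file names in the commented example are hypothetical.

import cv2
import numpy as np

def iou(predPath, labelPath):
    pred = cv2.imread(predPath, 0) > 127    # boolean mask of the predicted sealing region
    label = cv2.imread(labelPath, 0) > 127  # boolean mask of the ground-truth label
    inter = np.logical_and(pred, label).sum()
    union = np.logical_or(pred, label).sum()
    return inter / union if union > 0 else 1.0

# hypothetical file names:
# print(iou('E:/py/dataCrack/testRes/0001.png', 'E:/py/dataCrack/00CC32label_2gf/0001.png'))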
