笔者的训练代码总共有4个:(1)xj0paremeters.py用于存放参数;(2)xj1ModelFCN.py是FCN模型;(3)xj2ImageDataset.py用于数据处理;(4)xj3train.py用于训练。
一、参数
xj0paremeters.py
import torch.cuda
# Training data: input image directory and the matching binary label masks.
dirTrainImage = 'E:/py/dataCrack/00AA1originalNote'
dirTrainLabel = 'E:/py/dataCrack/00CC32label_2gf'
# Fraction of the dataset used for training; the remainder is validation.
ratioTrainVal = 0.9
#region Training parameters
# Image preprocessing: resize target (rows x cols). scale == 0 disables resizing.
scale = 2
resizeRow = 512 * scale
resizeCol = 256 * scale
# Training hyper-parameters
device = ('cuda' if torch.cuda.is_available() else 'cpu')
batchSize = 1
numWorkers = 1
pEpochs = 501
pLearningRate = 1e-3
pMomentum = 0.7
#endregion
# Prediction: test-image input directory and result output directory.
dirPreImage = 'E:/py/dataCrack/test'
dirPreResult = 'E:/py/dataCrack/testRes'
二、模型
xj1ModelFCN.py
见:语义分割示例—FCN识别路面灌缝区域(2)FCN
三、数据处理
源代码和注释见:pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)
xj2ImageDataset.py
# coding=utf-8
import os,cv2
import numpy as np
import torch
import torchvision
import warnings
warnings.filterwarnings("ignore")
from xj0paremeters import dirTrainImage,dirTrainLabel,ratioTrainVal
from xj0paremeters import scale,resizeRow,resizeCol
from xj0paremeters import batchSize,numWorkers
# torchvision preprocessing: HWC uint8 image -> CHW float tensor in [0, 1],
# then normalized with the ImageNet mean/std (standard for VGG backbones).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
])
# One-hot encode the label map; the label is read as a grayscale image and
# n is the number of segmentation classes.
def onehot(label, n):
    """One-hot encode an integer label map.

    Parameters
    ----------
    label : np.ndarray
        Integer class map with values in [0, n).
    n : int
        Number of classes.

    Returns
    -------
    np.ndarray of shape ``label.shape + (n,)`` where, for each pixel,
    channel ``c`` is 1 iff ``label == c`` (all other channels 0).
    """
    buf = np.zeros(label.shape + (n,))
    # Flat index of the "hot" channel: pixel i with class l lives at
    # i*n + l in the raveled (..., n) buffer.
    n_msk = np.arange(label.size) * n + label.ravel()
    # BUG FIX: the original wrote buf.ravel()[n_msk - 1] = 1, which shifts
    # every encoding one slot left (class 0 of pixel i landed in the last
    # channel of pixel i-1), corrupting the training targets.
    buf.ravel()[n_msk] = 1
    return buf
# process the image: read, resize, toTensor
class ImageDataset(torch.utils.data.Dataset):
    """Pairs each training image with its label mask.

    Images come from ``dirTrainImage``; the label of ``<name>.<ext>`` is
    expected at ``dirTrainLabel/<name>.png`` as a grayscale mask
    (0 = background, 255 = foreground).
    """

    def __init__(self, transform=None):
        self.transform = transform
        # Cache a sorted listing once: the original called os.listdir() in
        # both __len__ and __getitem__, which re-hits the filesystem per
        # sample and, because listdir order is arbitrary, is not
        # deterministic between calls.
        self.imageNames = sorted(os.listdir(dirTrainImage))

    def __len__(self):
        return len(self.imageNames)

    def __getitem__(self, item):
        imageName = self.imageNames[item]
        imagePath = dirTrainImage + '/' + imageName
        imageA = cv2.imread(imagePath)
        if imageA is None:
            # Fail fast with the offending path instead of a cryptic
            # downstream error (cv2.imread returns None on failure).
            raise FileNotFoundError(imagePath)
        if scale != 0:
            imageA = cv2.resize(imageA, (resizeCol, resizeRow))
        baseName, extension = os.path.splitext(imageName)
        labelPath = dirTrainLabel + '/' + baseName + '.png'
        imageB = cv2.imread(labelPath, 0)  # 0 -> read as grayscale
        if imageB is None:
            raise FileNotFoundError(labelPath)
        if scale != 0:
            imageB = cv2.resize(imageB, (resizeCol, resizeRow))
        # 0/255 mask -> 0/1 class ids -> (2, H, W) one-hot float tensor.
        imageB = imageB / 255
        imageB = imageB.astype('uint8')
        imageB = onehot(imageB, 2)
        imageB = imageB.transpose(2, 0, 1)
        imageB = torch.FloatTensor(imageB)
        if self.transform:
            imageA = self.transform(imageA)
        return imageA, imageB
# Build the dataset and split it into train / validation loaders.
pouring = ImageDataset(transform)
trainSize = int(ratioTrainVal*len(pouring))
valSize = len(pouring) - trainSize
trainDataset, valDataset = torch.utils.data.random_split(pouring,[trainSize,valSize])
trainLoader = torch.utils.data.DataLoader(trainDataset,batch_size=batchSize,shuffle=True,num_workers=numWorkers)
# NOTE(review): shuffling the validation loader has no effect on the loss
# value; shuffle=False would be the conventional choice here.
valLoader = torch.utils.data.DataLoader(valDataset,batch_size=batchSize,shuffle=True,num_workers=numWorkers)
四、训练
xj3train.py
import time
import numpy as np
import torch
import visdom
import warnings
warnings.filterwarnings('ignore')
from xj0paremeters import pEpochs,pLearningRate,pMomentum
from xj0paremeters import device
from xj1ModelFCN import FCNs,VGG_Net
from xj2ImageDataset import trainLoader, valLoader
def train(epoch):
    """Run one training epoch; plot the mean loss and checkpoint the model.

    Returns the mean training loss over the epoch. Relies on module-level
    globals set up in __main__: fcn_model, optimizer, criterion,
    trainLoader, visdomTL, visdomTLwindow.
    """
    timeStart = time.time()
    trainLoss = 0
    fcn_model.train()
    for index, (gf, gf_label) in enumerate(trainLoader):
        gf = gf.to(device)              # (B, 3, H, W) input batch
        gf_label = gf_label.to(device)  # (B, 2, H, W) one-hot targets
        optimizer.zero_grad()
        output = fcn_model(gf)
        # BCELoss expects probabilities, hence the sigmoid on the logits.
        output = torch.sigmoid(output)
        loss = criterion(output, gf_label)
        loss.backward()
        iterLoss = loss.item()
        trainLoss += iterLoss
        optimizer.step()
        if index % 100 == 0:
            print('[{}/{}]\tloss={:.4f}'.format(index, len(trainLoader), iterLoss))
    trainLoss /= len(trainLoader)
    timeEnd = time.time()
    print('train loss=%.4f, time=%.2f' % (trainLoss, (timeEnd - timeStart)))
    visdomTL.line([trainLoss], [epoch], win=visdomTLwindow, update='append')
    # BUG FIX: the original reset minTrainLoss = 999999 on every call, so the
    # "minimum loss" checkpoint was overwritten every single epoch. Track the
    # best loss across calls on the function object instead.
    if trainLoss < getattr(train, 'minTrainLoss', float('inf')):
        train.minTrainLoss = trainLoss
        torch.save(fcn_model, 'model/minTrainLoss_fcn_model.pth')
    # Periodic snapshots every other epoch once the loss is low enough.
    if (np.mod(epoch, 2) == 0) and (trainLoss < 0.05):
        torch.save(fcn_model, 'model/fcn_model_{}.pth'.format(epoch))
        torch.save(fcn_model, 'model/fcn_model_{}.pt'.format(epoch))
        print('saved the model to model/fcn_model_{}.pth'.format(epoch))
    return trainLoss
def validate(epoch):
    """Evaluate one epoch on the validation set; return the mean loss.

    Relies on module-level globals set up in __main__: fcn_model,
    criterion, valLoader, visdomVL, visdomVLwindow.
    """
    timeStart = time.time()
    valLoss = 0
    fcn_model.eval()
    with torch.no_grad():
        for index, (gf, gf_label) in enumerate(valLoader):
            gf = gf.to(device)
            gf_label = gf_label.to(device)
            # NOTE: the original called optimizer.zero_grad() here; it is
            # meaningless during no_grad() evaluation and has been removed.
            output = fcn_model(gf)
            output = torch.sigmoid(output)
            loss = criterion(output, gf_label)
            iterLoss = loss.item()
            valLoss += iterLoss
    valLoss /= len(valLoader)
    timeEnd = time.time()
    print('val loss=%f, time=%.2f' % (valLoss, (timeEnd - timeStart)))
    visdomVL.line([valLoss], [epoch], win=visdomVLwindow, update='append')
    # BUG FIX: minValLoss was reset to 999999 on every call, so the "best"
    # checkpoint was saved every epoch regardless of the loss. Persist the
    # best value across calls on the function object.
    if valLoss < getattr(validate, 'minValLoss', float('inf')):
        validate.minValLoss = valLoss
        torch.save(fcn_model, 'model/minValLoss_fcn_model.pth')
    return valLoss
if __name__ == "__main__":
    # Start the visdom server first: python -m visdom.server
    visdomTL = visdom.Visdom()
    visdomTLwindow = visdomTL.line([0],[0],opts=dict(title='train_loss'))
    visdomVL = visdom.Visdom()
    visdomVLwindow = visdomVL.line([0],[0],opts=dict(title='validate_loss'))
    visdomTVL = visdom.Visdom(env='FCN')
    # Model: VGG backbone + FCN head with 2 output classes.
    vgg_model = VGG_Net(requires_grad=True,show_params=False)
    fcn_model = FCNs(pretrained_net=vgg_model,n_class=2).to(device)
    # BCELoss pairs with the sigmoid applied in train()/validate().
    criterion = torch.nn.BCELoss().to(device)
    optimizer = torch.optim.SGD(fcn_model.parameters(),lr=pLearningRate,momentum=pMomentum)
    for epoch in range(pEpochs):
        print('epoch=',epoch,'----------------------')
        trainLoss = train(epoch)
        validateLoss = validate(epoch)
        # Combined train/validation loss curve in the 'FCN' visdom env.
        visdomTVL.line(X=[epoch],Y=[[trainLoss],[validateLoss]],
                       win='TVL',update='append',
                       opts=dict(showlegend=True,
                                 markers=False,
                                 title='FCN train validate loss',
                                 xlabel='epoch',ylabel='loss',
                                 legend=["train loss", "validate loss"]))
五、测试
既然训练完了,那就(开盲盒)测试一下吧。
xj4Predict.py
import numpy as np
import cv2,time,os,glob
import torch,torchvision
from torch.multiprocessing import Pool
from xj0paremeters import device
from xj0paremeters import scale,resizeRow,resizeCol
from xj0paremeters import dirPreImage,dirPreResult
# torchvision preprocessing — must match the training pipeline exactly:
# HWC uint8 -> CHW float in [0, 1], normalized with ImageNet mean/std.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
])
def PredictImage(path):
    """Segment one image and write the binary mask to dirPreResult.

    The mask is predicted at the training resolution, resized back to the
    input image's original size, and saved as <name>.png. Relies on the
    module-level globals ``model`` and ``transform`` set up in __main__.
    """
    imageA = cv2.imread(path)
    imageHeight, imageWidth = imageA.shape[:2]
    if scale != 0:
        imageA = cv2.resize(imageA, (resizeCol, resizeRow))
    imageA = transform(imageA).to(device)
    imageA = imageA.unsqueeze(0)  # add batch dim -> (1, 3, H, W)
    output = model(imageA)        # (1, 2, H, W) logits
    output = torch.sigmoid(output).squeeze().cpu().detach().numpy().round()
    # BUG FIX: select the winning class per pixel. Under the standard
    # one-hot convention (channel index == class id, foreground == 1) that
    # is argmax; the original argmin inverted the mask.
    output = np.argmax(output, axis=0)
    pre = cv2.resize((output * 255).astype('uint8'), (imageWidth, imageHeight))
    # Mirror the input path into the result directory and always save as
    # .png (the original str.replace only handled '.jpg' inputs).
    base, _ = os.path.splitext(path.replace(dirPreImage, dirPreResult))
    resultPath = base + '.png'
    cv2.imwrite(resultPath, pre)
    print(resultPath)
if __name__ == '__main__':
    # Load the checkpoint with the lowest training loss (full-model pickle,
    # so the FCN class definitions must be importable).
    model = torch.load('E:/py/FCN/model/minTrainLoss_fcn_model.pth')
    # model = torch.load('E:/py/FCN/model/minValLoss_fcn_model.pth')
    model = model.to(device)
    model.eval()
    # Make sure the output directory exists before writing any results.
    if not os.path.isdir(dirPreResult):
        os.makedirs(dirPreResult)
    timeStart = time.time()
    # Segment every file found in the test-image directory.
    for fileName in os.listdir(dirPreImage):
        PredictImage(dirPreImage + '/' + fileName)
    timeEnd = time.time()
    print('处理完成时间=%.2f秒' % (timeEnd - timeStart))
结果不大好…