整个工程文件已放到Github上
https://github.com/yaoyi30/PyTorch_Image_Segmentation
一、训练图像分类网络主要流程
- 构建数据集
- 数据预处理,包括数据增强、数据标准化和归一化
- 构建网络模型
- 设置学习率、优化器、损失函数等超参数
- 训练和验证
二、各个流程简要说明
1. 构建数据集
本文使用supervisely 发布的人像分割数据集,百度网盘地址:
https://pan.baidu.com/s/1B8eBqg7XROHOsm5OLw-t9g 提取码: 52ss
在工程目录下,新建datasets文件夹,在文件夹内分别新建images和labels文件夹,用来放图片和对应的mask图片,之后在两个文件夹内新建train和val文件夹用来存放训练和验证数据,结构如下:
datasets/
images/ # images
train/
img1.jpg
img2.jpg
.
.
.
val/
img1.jpg
img2.jpg
.
.
.
labels/ # masks
train/
img1.png
img2.png
.
.
.
val/
img1.png
img2.png
.
.
.
2. 数据预处理
将图像resize到统一大小,之后转为tensor格式再进行标准化,预处理之后的图片可以正常输入网络,对于训练集可以采取一些数据增强手段来增强网络的泛化能力,验证集不做数据增强。
# Training-data preprocessing / augmentation settings
train_transform = Compose([
    Resize(args.input_size),  # resize image and mask to a common size
    RandomHorizontalFlip(0.5),  # augmentation: random horizontal flip
    ToTensor(),  # convert to tensor; pixel values scaled into [0, 1]
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # standardisation
])
# Validation-data preprocessing (no augmentation)
val_transform = Compose([
    Resize(args.input_size),  # resize image and mask to a common size
    ToTensor(),  # convert to tensor; pixel values scaled into [0, 1]
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # standardisation
])
3. 构建网络模型
本文搭建了一个简单的图像分割网络,命名为Simplify_Net。
# Instantiate the segmentation network with the desired number of classes
model = Simplify_Net(args.nb_classes)
4. 设置学习率、优化器、损失函数等超参数
# Loss function: segmentation is per-pixel classification, so cross-entropy applies
loss_function = nn.CrossEntropyLoss()
# Optimizer (initial learning rate and weight decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
# Learning-rate schedule: OneCycleLR (warms up to max_lr then anneals), stepped once per epoch
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)
5. 训练和验证
# Train and validate the model; the implementation lives in utils/engine.py
history = train_and_val(args.epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,args.output_dir,device,args.nb_classes)
三、工程代码文件详细讲解
train.py
定义训练的入口函数,以及训练所需要的流程
1. 导入相应的库和文件
import os
import torch
import torch.nn as nn
from models.Simplify_Net import Simplify_Net
from utils.engine import train_and_val,plot_pix_acc,plot_miou,plot_loss,plot_lr
import argparse
import numpy as np
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData
2. 训练参数设置
def get_args_parser():
    """Build the command-line argument parser for training.

    Returns:
        argparse.ArgumentParser with data, optimization and runtime options.
    """
    parser = argparse.ArgumentParser('Image Segmentation Train', add_help=False)
    # Data / loader settings.
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str, help='dataset path')
    # Optimization hyper-parameters.
    # Fix: help text said 'intial lr' (typo).
    parser.add_argument('--init_lr', default=1e-5, type=float, help='initial lr')
    parser.add_argument('--max_lr', default=1e-3, type=float, help='max lr')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of the classification types')
    # Runtime settings.
    parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)
    return parser
3. 定义主函数
def main(args):
    """Assemble data pipelines, model and optimization objects, then train."""
    device = torch.device(args.device)

    # Make sure the checkpoint/plot directory exists.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Training pipeline: resize -> random flip -> tensor -> normalize.
    train_transform = Compose([
        Resize(args.input_size),
        RandomHorizontalFlip(0.5),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Validation pipeline: identical, minus augmentation.
    val_transform = Compose([
        Resize(args.input_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = SegData(
        image_path=os.path.join(args.data_path, 'images/train'),
        mask_path=os.path.join(args.data_path, 'labels/train'),
        data_transforms=train_transform,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    val_dataset = SegData(
        image_path=os.path.join(args.data_path, 'images/val'),
        mask_path=os.path.join(args.data_path, 'labels/val'),
        data_transforms=val_transform,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    model = Simplify_Net(args.nb_classes)
    # Per-pixel cross-entropy loss for segmentation.
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    # One scheduler step per epoch (total_steps == epochs).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)

    history = train_and_val(args.epochs, model, train_loader, val_loader, loss_function,
                            optimizer, scheduler, args.output_dir, device, args.nb_classes)

    # Plot metric curves against the epoch index.
    epoch_axis = np.arange(0, args.epochs)
    plot_loss(epoch_axis, args.output_dir, history)
    plot_pix_acc(epoch_axis, args.output_dir, history)
    plot_miou(epoch_axis, args.output_dir, history)
    plot_lr(epoch_axis, args.output_dir, history)
4. 开始执行
if __name__ == '__main__':
    # Build the training argument parser
    args = get_args_parser()
    # Parse the command-line arguments
    args = args.parse_args()
    # Hand the parsed arguments to the main training routine
    main(args)
运行train.py,训练时打印的信息,包括每一轮的学习率,训练集和验证集指标,运行时间等
Simplify_Net.py
定义网络结构,本文定义一个简单的Encoder-Decoder结构的卷积神经网络
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
class Simplify_Net(nn.Module):
    """A minimal encoder-decoder segmentation network with two skip connections.

    Three stride-2 convolutions downsample by 8x; two transposed convolutions
    upsample back by 4x with concatenated skip features, and a final bilinear
    interpolation restores the input resolution. Output: (N, num_classes, H, W).
    """

    def __init__(self, num_classes=2):
        super(Simplify_Net, self).__init__()
        # Encoder: three stride-2 3x3 conv blocks (each halves H and W).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=2)
        self.bn3 = nn.BatchNorm2d(16)
        self.relu3 = nn.ReLU(inplace=True)
        # Decoder: two stride-2 transposed convs; inputs widen to 32 channels
        # because skip features are concatenated before each upconv.
        self.upconv1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=4, padding=1, stride=2)
        self.bn4 = nn.BatchNorm2d(16)
        self.relu4 = nn.ReLU(inplace=True)
        self.upconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, padding=1, stride=2)
        self.bn5 = nn.BatchNorm2d(16)
        self.relu5 = nn.ReLU(inplace=True)
        # 1x1 conv producing per-class logits.
        self.conv_last = nn.Conv2d(in_channels=32, out_channels=num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        # Encoder path (H/2, H/4, H/8).
        enc1 = self.relu1(self.bn1(self.conv1(x)))
        enc2 = self.relu2(self.bn2(self.conv2(enc1)))
        enc3 = self.relu3(self.bn3(self.conv3(enc2)))
        # Decoder path with skip concatenations.
        dec1 = self.relu4(self.bn4(self.upconv1(enc3)))
        skip1 = torch.cat([enc2, dec1], dim=1)
        dec2 = self.relu5(self.bn5(self.upconv2(skip1)))
        skip2 = torch.cat([enc1, dec2], dim=1)
        logits = self.conv_last(skip2)
        # Final 2x bilinear upsample back to the input resolution.
        return interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)
utils/datasets.py
定义数据读取的类
import os
from torch.utils.data import Dataset
from PIL import Image
class SegData(Dataset):
    """Dataset pairing images in `image_path` with same-named .png masks in `mask_path`.

    Args:
        image_path: directory of input images (e.g. .jpg/.jpeg).
        mask_path: directory of label masks (.png, same stem as the image).
        data_transforms: optional callable applied as (image, mask) -> (image, mask).
    """

    def __init__(self, image_path, mask_path, data_transforms=None):
        self.image_path = image_path
        self.mask_path = mask_path
        # Sorted for a deterministic sample order across runs and platforms
        # (os.listdir order is filesystem-dependent).
        self.images = sorted(os.listdir(self.image_path))
        self.masks = sorted(os.listdir(self.mask_path))
        self.transform = data_transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_filename = self.images[idx]
        # Fix: derive the mask name from the filename stem so any image
        # extension maps to '.png'. The old replace('jpeg', 'png') left
        # '.jpg' filenames untouched and the mask lookup then failed.
        mask_filename = os.path.splitext(image_filename)[0] + '.png'
        image = Image.open(os.path.join(self.image_path, image_filename)).convert('RGB')
        mask = Image.open(os.path.join(self.mask_path, mask_filename)).convert('L')
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask
utils/transform.py
定义数据预处理的类
import numpy as np
import random
import torch
from torchvision.transforms import functional as F
# 将img和mask resize到统一大小
class Resize(object):
    """Resize the image (default interpolation) and, if present, the mask.

    The mask is resized with NEAREST so class ids are never interpolated.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target=None):
        resized_image = F.resize(image, self.size)
        if target is None:
            return resized_image, target
        resized_target = F.resize(target, self.size, interpolation=F.InterpolationMode.NEAREST)
        return resized_image, resized_target
#随机左右翻转
class RandomHorizontalFlip(object):
    """Flip image (and mask, when given) left-right with probability `flip_prob`."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        # Single RNG draw decides the flip for both image and mask so they
        # stay spatially aligned.
        if random.random() >= self.flip_prob:
            return image, target
        flipped_image = F.hflip(image)
        flipped_target = F.hflip(target) if target is not None else target
        return flipped_image, flipped_target
#标准化
class Normalize(object):
    """Channel-wise standardisation of the image tensor; the mask passes through.

    Args:
        mean: per-channel means.
        std: per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        # Fix: `target` now defaults to None for consistency with Resize and
        # RandomHorizontalFlip, so Normalize can also be applied image-only.
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target
#img转tensor,值变为0-1之间,label直接转为tensor
class ToTensor(object):
    """Convert the image to a float tensor in [0, 1]; the mask to an int64 tensor."""

    def __call__(self, image, target=None):
        # Fix: `target` now defaults to None for consistency with the other
        # transforms, so ToTensor can also be applied image-only.
        image = F.to_tensor(image)
        if target is not None:
            # Masks hold class indices, so convert without rescaling.
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target
#不同数据预处理类组合起来
class Compose(object):
    """Chain several (image, mask) transforms into a single callable."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask=None):
        # Each transform receives and returns the (image, mask) pair.
        for transform in self.transforms:
            image, mask = transform(image, mask)
        return image, mask
utils/metrics.py
定义计算像素准确率、MIoU等指标的类
import numpy as np
class Evaluator(object):
    """Confusion-matrix accumulator for segmentation metrics.

    Rows index ground-truth classes, columns index predicted classes.
    """

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Pixel_Accuracy(self):
        """Overall fraction of correctly classified pixels."""
        cm = self.confusion_matrix
        return np.diag(cm).sum() / cm.sum()

    def Pixel_Accuracy_Class(self):
        """Mean per-class recall (NaN classes are ignored)."""
        cm = self.confusion_matrix
        per_class = np.diag(cm) / cm.sum(axis=1)
        return np.nanmean(per_class)

    def Mean_Intersection_over_Union(self):
        """Return (per-class IoU array, mean IoU)."""
        cm = self.confusion_matrix
        true_positive = np.diag(cm)
        union = cm.sum(axis=1) + cm.sum(axis=0) - true_positive
        IoU = true_positive / union
        return IoU, np.nanmean(IoU)

    def Frequency_Weighted_Intersection_over_Union(self):
        """IoU weighted by each class's ground-truth pixel frequency."""
        cm = self.confusion_matrix
        freq = cm.sum(axis=1) / cm.sum()
        true_positive = np.diag(cm)
        iu = true_positive / (cm.sum(axis=1) + cm.sum(axis=0) - true_positive)
        return (freq[freq > 0] * iu[freq > 0]).sum()

    def _generate_matrix(self, gt_image, pre_image):
        """Build a confusion matrix for one batch of label/prediction arrays."""
        # Only pixels whose ground truth is a valid class id contribute.
        valid = (gt_image >= 0) & (gt_image < self.num_class)
        # Encode (gt, pred) pairs as a single index: gt * C + pred.
        combined = self.num_class * gt_image[valid].astype('int') + pre_image[valid]
        counts = np.bincount(combined, minlength=self.num_class ** 2)
        return counts.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        """Accumulate one batch into the running confusion matrix."""
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Clear all accumulated statistics."""
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
utils/engine.py
定义具体的训练、验证以及绘制指标曲线的函数
1. 导入相应的库和文件
import os
import torch
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils.metrics import Evaluator
import numpy as np
2. 训练验证函数
def train_and_val(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, output_dir, device, nb_classes):
    """Train and validate `model` for `epochs` epochs.

    Saves 'last.pth' after every epoch and 'best.pth' only when validation
    MIoU improves. Returns a history dict with per-epoch train/val loss,
    pixel accuracy, MIoU and learning rate.
    """
    train_loss, val_loss = [], []
    train_pix_acc, val_pix_acc = [], []
    train_miou, val_miou = [], []
    learning_rate = []
    best_miou = 0
    segmetric_train = Evaluator(nb_classes)
    segmetric_val = Evaluator(nb_classes)
    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        torch.cuda.empty_cache()
        segmetric_train.reset()
        segmetric_val.reset()
        since = time.time()

        # ---- training pass ----
        training_loss = 0
        model.train()
        with tqdm(total=len(train_loader)) as pbar:
            for image, label in train_loader:
                image = image.to(device)
                label = label.to(device)
                output = model(image)
                loss = criterion(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Argmax over the class dimension gives per-pixel predictions.
                pred = np.argmax(output.data.cpu().numpy(), axis=1)
                label = label.cpu().numpy()
                training_loss += loss.item()
                segmetric_train.add_batch(label, pred)
                pbar.update(1)

        # ---- validation pass ----
        model.eval()
        validation_loss = 0
        with torch.no_grad():
            with tqdm(total=len(val_loader)) as pb:
                for image, label in val_loader:
                    image = image.to(device)
                    label = label.to(device)
                    output = model(image)
                    loss = criterion(output, label)
                    pred = np.argmax(output.data.cpu().numpy(), axis=1)
                    label = label.cpu().numpy()
                    validation_loss += loss.item()
                    segmetric_val.add_batch(label, pred)
                    pb.update(1)

        # Compute epoch metrics once instead of re-deriving them repeatedly.
        epoch_train_acc = segmetric_train.Pixel_Accuracy()
        epoch_val_acc = segmetric_val.Pixel_Accuracy()
        epoch_train_miou = segmetric_train.Mean_Intersection_over_Union()[1]
        epoch_val_miou = segmetric_val.Mean_Intersection_over_Union()[1]

        train_loss.append(training_loss / len(train_loader))
        val_loss.append(validation_loss / len(val_loader))
        train_pix_acc.append(epoch_train_acc)
        val_pix_acc.append(epoch_val_acc)
        train_miou.append(epoch_train_miou)
        val_miou.append(epoch_val_miou)
        learning_rate.append(scheduler.get_last_lr())

        torch.save(model.state_dict(), os.path.join(output_dir, 'last.pth'))
        if best_miou < epoch_val_miou:
            # Fix: best_miou was never updated, so 'best.pth' was overwritten
            # on every epoch with a positive MIoU instead of only on improvement.
            best_miou = epoch_val_miou
            torch.save(model.state_dict(), os.path.join(output_dir, 'best.pth'))

        print("Epoch:{}/{}..".format(e + 1, epochs),
              "Train Pix Acc: {:.3f}".format(epoch_train_acc),
              "Val Pix Acc: {:.3f}".format(epoch_val_acc),
              "Train MIoU: {:.3f}".format(epoch_train_miou),
              "Val MIoU: {:.3f}".format(epoch_val_miou),
              "Train Loss: {:.3f}".format(training_loss / len(train_loader)),
              "Val Loss: {:.3f}".format(validation_loss / len(val_loader)),
              "Time: {:.2f}s".format((time.time() - since)))
        scheduler.step()

    history = {'train_loss': train_loss, 'val_loss': val_loss, 'train_pix_acc': train_pix_acc, 'val_pix_acc': val_pix_acc, 'train_miou': train_miou, 'val_miou': val_miou, 'lr': learning_rate}
    print('Total time: {:.2f} m'.format((time.time() - fit_time) / 60))
    return history
3. 打印损失值曲线
def plot_loss(x, output_dir, history):
    """Plot train/val loss per epoch and save the figure as loss.png."""
    plt.plot(x, history['val_loss'], label='val', marker='o')
    plt.plot(x, history['train_loss'], label='train', marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'loss.png'))
    # Clear the figure so subsequent plots start fresh.
    plt.clf()
4. 打印像素准确率曲线
def plot_pix_acc(x, output_dir, history):
    """Plot train/val pixel accuracy per epoch and save the figure as pix_acc.png."""
    plt.plot(x, history['train_pix_acc'], label='train_pix_acc', marker='x')
    plt.plot(x, history['val_pix_acc'], label='val_pix_acc', marker='x')
    plt.title('Pix Acc per epoch')
    # Fix: axis label typo 'pixal accuracy' -> 'pixel accuracy'.
    plt.ylabel('pixel accuracy')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'pix_acc.png'))
    # Clear the figure so subsequent plots start fresh.
    plt.clf()
网络结构较为简单,因此像素准确率不是特别的高
5. 打印MIoU曲线
def plot_miou(x, output_dir, history):
    """Plot train/val MIoU per epoch and save the figure as miou.png."""
    plt.plot(x, history['train_miou'], label='train_miou', marker='x')
    plt.plot(x, history['val_miou'], label='val_miou', marker='x')
    plt.title('MIoU per epoch')
    plt.ylabel('miou')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'miou.png'))
    # Clear the figure so subsequent plots start fresh.
    plt.clf()
网络结构较为简单,因此MIoU不是特别的高
6. 打印学习率曲线
def plot_lr(x, output_dir, history):
    """Plot the learning-rate schedule per epoch and save it as learning_rate.png."""
    plt.plot(x, history['lr'], label='learning_rate', marker='x')
    plt.title('learning rate per epoch')
    plt.ylabel('Learning_rate')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(output_dir, 'learning_rate.png'))
    # Clear the figure so subsequent plots start fresh.
    plt.clf()
从学习率曲线可以看出,约前30轮为warmup阶段,最大学习率为0.001
predict.py
进行单张图片预测
1. 导入相应的库和文件
import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
from models.Simplify_Net import Simplify_Net
from PIL import Image
2. 单张预测参数设置
def get_args_parser():
    """Build the command-line argument parser for single-image prediction.

    Returns:
        argparse.ArgumentParser with input image, model and runtime options.
    """
    parser = argparse.ArgumentParser('Predict Image', add_help=False)
    # Fix: help text and metavar were copy-pasted from a training script
    # ('Name of model to train' / metavar='MODEL' on an image path).
    parser.add_argument('--image_path', default='./people-man-model-glasses-46219.jpeg', type=str, help='path of the input image')
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    # Fix: help text said 'dataset path' for the weights file.
    parser.add_argument('--weights', default='./output_dir/last.pth', type=str, help='path of the model weights')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of the classification types')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    return parser
3. 定义主函数
def main(args):
    """Run single-image inference and save the predicted mask as result.png."""
    device = torch.device(args.device)

    image = Image.open(args.image_path).convert('RGB')
    # (width, height) of the original image, remembered to restore the output.
    img_size = image.size

    transforms = T.Compose([
        T.Resize(args.input_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    model.to(device)
    model.eval()

    input_tensor = transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)

    # Fix: the original used np.uint8 although numpy was never imported in
    # predict.py (NameError at runtime); cast through torch instead.
    pred = output.argmax(1).squeeze(0).to(torch.uint8).cpu().numpy()
    mask = Image.fromarray(pred)
    # Fix: a class-id mask must be resized with nearest-neighbour — the
    # default filter interpolates and produces invalid class values.
    out = mask.resize(img_size, Image.NEAREST)
    out.save("result.png")
4. 开始执行
if __name__ == '__main__':
    # Build the prediction argument parser
    args = get_args_parser()
    # Parse the command-line arguments
    args = args.parse_args()
    # Hand the parsed arguments to the prediction routine
    main(args)
运行predict.py,保存模型预测的结果
eval.py
进行模型整体指标评价
1. 导入相应的库和文件
import argparse
from utils.transform import Resize,Compose,ToTensor,Normalize,RandomHorizontalFlip
from utils.datasets import SegData
import torch
import os
import numpy as np
from tqdm import tqdm
from models.Simplify_Net import Simplify_Net
from utils.metrics import Evaluator
2. 模型评价参数设置
def get_args_parser():
    """Build the command-line argument parser for model evaluation.

    Returns:
        argparse.ArgumentParser with data, model and runtime options.
    """
    parser = argparse.ArgumentParser('Eval Model', add_help=False)
    # Fix: help text said 'Batch size for training' in the evaluation script.
    parser.add_argument('--batch_size', default=1, type=int, help='Batch size for evaluation')
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    parser.add_argument('--data_path', default='./datasets/', type=str, help='dataset path')
    # Fix: help text said 'dataset path' for the weights file.
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str, help='path of the model weights')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of the classification types')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--num_workers', default=4, type=int)
    return parser
3. 定义主函数
def main(args):
    """Evaluate a trained checkpoint on the validation split and print metrics."""
    device = torch.device(args.device)

    # Confusion-matrix based metric accumulator.
    segmetric = Evaluator(args.nb_classes)
    segmetric.reset()

    # Validation pipeline: resize -> tensor -> normalize (no augmentation).
    val_transform = Compose([
        Resize(args.input_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = SegData(
        image_path=os.path.join(args.data_path, 'images/val'),
        mask_path=os.path.join(args.data_path, 'labels/val'),
        data_transforms=val_transform,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    model.to(device)
    model.eval()

    classes = ["background", "human"]
    with torch.no_grad():
        with tqdm(total=len(val_loader)) as pbar:
            for image, label in val_loader:
                logits = model(image.to(device))
                # Per-pixel class predictions via argmax over the class axis.
                predicted = np.argmax(logits.data.cpu().numpy(), axis=1)
                segmetric.add_batch(label.cpu().numpy(), predicted)
                pbar.update(1)

    pix_acc = segmetric.Pixel_Accuracy()
    every_iou, miou = segmetric.Mean_Intersection_over_Union()
    print("Pixel Accuracy is :", pix_acc)
    print("==========Every IOU==========")
    for name, prob in zip(classes, every_iou):
        print(name + " : " + str(prob))
    print("=============================")
    print("MiOU is :", miou)
4. 开始执行
if __name__ == '__main__':
    # Build the evaluation argument parser
    args = get_args_parser()
    # Parse the command-line arguments
    args = args.parse_args()
    # Hand the parsed arguments to the evaluation routine
    main(args)
运行eval.py,打印模型在验证集上的像素准确率,MIoU值和每一类的IoU值
export_onnx.py
将训练好的模型转onnx格式,以进行后续应用
1. 导入相应的库和文件
import torch
from models.Simplify_Net import Simplify_Net
import argparse
2. 转onnx模型参数设置
def get_args_parser():
    """Build the command-line argument parser for ONNX export.

    Returns:
        argparse.ArgumentParser with input-size, weights and class-count options.
    """
    parser = argparse.ArgumentParser('Export Onnx', add_help=False)
    parser.add_argument('--input_size', default=[224, 224], nargs='+', type=int, help='images input size')
    # Fix: help text said 'dataset path' for the weights file.
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str, help='path of the model weights')
    parser.add_argument('--nb_classes', default=2, type=int, help='number of the classification types')
    return parser
3. 定义主函数
def main(args):
    """Export the trained checkpoint to an ONNX file next to the .pth file."""
    # Dummy input that fixes the exported graph's spatial size.
    x = torch.randn(1, 3, args.input_size[0], args.input_size[1])
    input_names = ["input"]
    out_names = ["output"]

    model = Simplify_Net(args.nb_classes)
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    model.eval()

    onnx_path = args.weights.replace('pth', 'onnx')
    # Fix: dropped the deprecated `training=False` flag (model.eval() already
    # puts the module in inference mode, and modern torch.onnx.export expects
    # a TrainingMode enum, not a bool).
    torch.onnx.export(model, x, onnx_path, export_params=True,
                      input_names=input_names, output_names=out_names)
    # Fix: the hint referenced a hard-coded 'test.onnx' instead of the file
    # actually written above.
    print('please run: python -m onnxsim ' + onnx_path + ' ' + onnx_path.replace('.onnx', '_sim.onnx') + '\n')
print('please run: python -m onnxsim test.onnx test_sim.onnx\n')
4. 开始执行
if __name__ == '__main__':
    # Build the export argument parser
    args = get_args_parser()
    # Parse the command-line arguments
    args = args.parse_args()
    # Hand the parsed arguments to the export routine
    main(args)
运行export_onnx.py,之后进行模型的简化
简化之前(左)和之后(右)的onnx模型结构对比