This post is the second hands-on tutorial in the series "Hands-On Tutorial! Semantic Segmentation from 0 to 1" and covers model training. It first introduces some common models and then focuses on how to train a semantic segmentation model on your own dataset. The training code is mainly based on this open-source repository: GitHub - qq995431104/pytorch_segmentation: Semantic segmentation models, datasets and losses implemented in PyTorch.
For an overview of the whole series and links to the other parts, see the series introduction post on my CSDN blog: "Hands-On Tutorial! Semantic Segmentation from 0 to 1: Introduction".
1. Common semantic segmentation networks
FCN marked the start of the deep learning era of semantic segmentation, and since then networks such as U-Net, SegNet, PSPNet and the DeepLab family have appeared one after another. If you want to dig deeper into any of these networks, see my column: https://blog.csdn.net/oyezhou/category_10704356.html, which covers several semantic segmentation networks along with other related topics.
In this post we will do the hands-on walkthrough with DeepLabV3+.
2. Training your own semantic segmentation model
2.1 Data preparation
The previous post in this series, "Hands-On Tutorial! Semantic Segmentation from 0 to 1: Part 1, Dataset Creation", explained how to build a semantic segmentation dataset. If you followed it step by step, you should by now have a batch of annotated data converted to the VOC format. Put the finished dataset in a directory of your choice; we will point the training code at it below.
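For reference, the dataset class defined later in this post expects a VOC-like layout under data_dir: a JPEGImages folder with the .jpg images, a SegmentationClass folder with the .png label masks, and one text file per split (train.txt / val.txt, matching the split values in the config) listing the image IDs. A minimal sketch for sanity-checking that layout (the path below is just an example):

import os

data_dir = "D:/dataset/my_dataset"  # example path; use wherever you placed your dataset

# The data loader later in this post reads from these locations
for entry in ["JPEGImages", "SegmentationClass", "train.txt", "val.txt"]:
    path = os.path.join(data_dir, entry)
    print(path, "OK" if os.path.exists(path) else "MISSING")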
2.2 Code preparation
First, git clone the open-source repository we will be using: GitHub - qq995431104/pytorch_segmentation: Semantic segmentation models, datasets and losses implemented in PyTorch. Then install the dependencies listed in its requirements file.
The repository implements several segmentation networks, such as FCN, U-Net, PSPNet and DeepLabV3+, all of which can be selected through the corresponding configuration parameters.
2.3 Editing the configuration file
To keep the training parameters in one place, the repository provides a configuration file (pytorch_segmentation/config.json) in which you can set the network backbone, the segmentation model, the dataset, the optimizer, the loss and other hyperparameters.
Here we use a VOC-style dataset and the DeepLabV3+ model; the full configuration is as follows:
{
    "name": "DeepLabv3_plus",
    "n_gpu": 1,
    "use_synch_bn": true,
    "arch": {
        "type": "DeepLab",
        "args": {
            "backbone": "resnet101",
            "freeze_bn": false,
            "freeze_backbone": false
        }
    },
    "train_loader": {
        "type": "MyVOC",
        "args": {
            "data_dir": "D:/dataset/my_dataset",
            "batch_size": 4,
            "base_size": 718,
            "crop_size": 718,
            "augment": true,
            "shuffle": true,
            "scale": true,
            "flip": true,
            "rotate": true,
            "blur": false,
            "split": "train",
            "num_workers": 0
        }
    },
    "val_loader": {
        "type": "MyVOC",
        "args": {
            "data_dir": "D:/dataset/my_dataset",
            "batch_size": 4,
            "crop_size": 718,
            "val": true,
            "split": "val",
            "num_workers": 0
        }
    },
    "optimizer": {
        "type": "SGD",
        "differential_lr": true,
        "args": {
            "lr": 0.005,
            "weight_decay": 1e-4,
            "momentum": 0.99
        }
    },
    "loss": "CrossEntropyLoss2d",
    "ignore_index": 255,
    "lr_scheduler": {
        "type": "Poly",
        "args": {}
    },
    "trainer": {
        "epochs": 120,
        "save_dir": "saved/",
        "save_period": 10,
        "monitor": "max Mean_IoU",
        "early_stop": 20,
        "tensorboard": true,
        "log_dir": "saved/runs",
        "log_per_iter": 10,
        "val": true,
        "val_per_epochs": 5
    }
}
You can use my configuration above as a template and adapt it to your own training setup; the fields you will most likely need to change are data_dir, batch_size, crop_size, the backbone and the number of epochs.
2.4 Dataset and DataLoader
The pytorch_segmentation/dataloaders directory contains DataLoader definitions for several common datasets, such as VOC and COCO. Since our data is in VOC format, we can start from voc.py and modify it. Below are my modified dataset and DataLoader definitions (remember to set num_classes and the normalization mean/std to match your own data):
# Originally written by Kazuto Nakashima
# https://github.com/kazuto1011/deeplab-pytorch
from base import BaseDataSet, BaseDataLoader
from utils import palette
import numpy as np
import os
from PIL import Image


class VOCDataset(BaseDataSet):
    """
    my VOC-like dataset
    """
    def __init__(self, **kwargs):
        self.num_classes = 2  # background + foreground; change this to match your own dataset
        self.palette = palette.get_voc_palette(self.num_classes)
        super(VOCDataset, self).__init__(**kwargs)

    def _set_files(self):
        # images, labels and split lists live in a VOC-like layout under self.root
        self.image_dir = os.path.join(self.root, 'JPEGImages')
        self.label_dir = os.path.join(self.root, 'SegmentationClass')
        file_list = os.path.join(self.root, self.split + ".txt")
        self.files = [line.rstrip() for line in tuple(open(file_list, "r"))]

    def _load_data(self, index):
        image_id = self.files[index]
        image_path = os.path.join(self.image_dir, image_id + '.jpg')
        label_path = os.path.join(self.label_dir, image_id + '.png')
        image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32)
        label = np.asarray(Image.open(label_path), dtype=np.int32)
        image_id = self.files[index].split("/")[-1].split(".")[0]
        return image, label, image_id


class MyVOC(BaseDataLoader):
    def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, val=False,
                 shuffle=False, flip=False, rotate=False, blur=False, augment=False, val_split=None, return_id=False):
        # update at 2021.01.29: per-channel mean/std used to normalize the images; recompute them for your own data
        self.MEAN = [0.4935838, 0.48873937, 0.45739236]
        self.STD = [0.22273207, 0.22567303, 0.22986929]

        kwargs = {
            'root': data_dir,
            'split': split,
            'mean': self.MEAN,
            'std': self.STD,
            'augment': augment,
            'crop_size': crop_size,
            'base_size': base_size,
            'scale': scale,
            'flip': flip,
            'blur': blur,
            'rotate': rotate,
            'return_id': return_id,
            'val': val
        }
        self.dataset = VOCDataset(**kwargs)
        super(MyVOC, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split)
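Two small points are worth checking before training. First, the config refers to the loader by name ("type": "MyVOC"), and in this repository the loader class is typically looked up by that name from the dataloaders package, so if you renamed the class or put it in a new file, make sure it is also exported from dataloaders/__init__.py (if it isn't already). A minimal sketch, assuming MyVOC stays in dataloaders/voc.py as above:

# dataloaders/__init__.py
from .voc import MyVOC

Second, the MEAN/STD values above are dataset-specific normalization statistics; when you switch to your own data it is worth recomputing them. A simple (if memory-hungry) sketch, reusing the example path from earlier:

import os
import numpy as np
from PIL import Image

data_dir = "D:/dataset/my_dataset"  # example path
image_dir = os.path.join(data_dir, "JPEGImages")
ids = [line.rstrip() for line in open(os.path.join(data_dir, "train.txt"))]

# Collect all pixels of the training images, scaled to [0, 1], then take per-channel statistics
pixels = []
for image_id in ids:
    img = np.asarray(Image.open(os.path.join(image_dir, image_id + ".jpg")).convert("RGB"), dtype=np.float32) / 255.0
    pixels.append(img.reshape(-1, 3))
pixels = np.concatenate(pixels, axis=0)
print("MEAN:", pixels.mean(axis=0))
print("STD: ", pixels.std(axis=0))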
In the VOCDataset class above, palette defines the color assigned to each class. I used the VOC definition, but you can also define your own color for each class:
palette.py:
def get_voc_palette(num_classes):
    n = num_classes
    palette = [0]*(n*3)
    for j in range(0, n):
        lab = j
        palette[j*3+0] = 0
        palette[j*3+1] = 0
        palette[j*3+2] = 0
        i = 0
        while (lab > 0):
            palette[j*3+0] |= (((lab >> 0) & 1) << (7-i))
            palette[j*3+1] |= (((lab >> 1) & 1) << (7-i))
            palette[j*3+2] |= (((lab >> 2) & 1) << (7-i))
            i = i + 1
            lab >>= 3
    return palette
ADE20K_palette = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,
3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,
5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,
255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,
6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,
92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,
10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,
0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,
163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224
,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,
200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,
163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,
255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,
255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,
255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,
122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,
255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,
255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,
0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,
0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,
20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,
255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,
255,214,0,25,194,194,102,255,0,92,0,255]
CityScpates_palette = [128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,
250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,
0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,
128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,
128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,
128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,
192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,
128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,
192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,
160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,
128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,
192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,
192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,
128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,
128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,
160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,
192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,
128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,
192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,
224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,
224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,
160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,
192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,
128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,
192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,
128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,
128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,
128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,
128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,
160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,
96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,
192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,
96,192,96,224,192,0,0,0]
COCO_palette = [31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227,
119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44,
214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207,
31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75,
227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44,
214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189,
34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75,
227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189,
34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103,
189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127
, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103,
189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14,
44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127,
127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189,
140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44,
160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190,
207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194,
127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148,
103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127,
14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34,
23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227,
119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39,
40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119,
180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127,
127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14]
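A common way to use such a palette is when saving or visualizing a predicted mask: the class-index array is stored as a palettized ("P" mode) PNG, and the palette supplies the color for each index. A minimal sketch of how the output of get_voc_palette can be applied with PIL (the prediction here is just dummy data):

import numpy as np
from PIL import Image
from utils import palette  # the same helper used by the dataset code above

colors = palette.get_voc_palette(2)  # 2 classes: background + foreground

# Dummy prediction: an HxW array of class indices
pred = np.zeros((256, 256), dtype=np.uint8)
pred[64:192, 64:192] = 1

mask = Image.fromarray(pred)  # 8-bit grayscale image of class indices
mask.putpalette(colors)       # attach the palette; the image becomes "P" mode
mask.save("pred_colored.png")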
2.5 Start training
Once the steps above are done, you can start training. Launch it from train.py:
python train.py --config config.json
Alternatively, run it directly in PyCharm.
2.6 Monitoring training
During training you can monitor progress with TensorBoard (event files are written under saved/runs, as set by log_dir in the config; pointing --logdir at the parent saved/ directory also works because TensorBoard scans subdirectories):
tensorboard --logdir saved
3. What's next
This post covered model training; to actually use the trained model we will also need inference code. Although the repository ships its own inference code, our goal is to pull the inference part out into a small, self-contained project that only handles inference and visualization. The next post, which is also the last in this series, will focus on extracting this inference pipeline into a compact project that can be used directly in real applications.