MMSegmentation 保持了 MM 系列一贯的风格,拥有灵活的模块化设计和全面的高性能model zoo。目前我们支持非常多的主流backbone和语义分割算法,支持多种数据集如 Cityscapes,ADE20K,Pascal VOC 2012上的训练结果(目前应该是语义分割中最大的 模型库)。
一. 数据准备
这个根据自己的数据而定,在mmsegmentataion里,任意算法几乎都提供了不同的数据集训练的预训练模型,有两种准备数据的方式。
1. 把自己的数据修改成匹配到对应数据集格式。如cityscope格式。
2. 自定义自己的数据格式。
本文以自定义自己的数据格式为例,记录mmsegmentataion的训练过程。
参考:新增自定义数据集 — MMSegmentation 1.1.0 文档
把自己的数据集做成下列格式:
├── data
│ ├── my_dataset
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{img_suffix}
│ │ │ │ ├── yyy{img_suffix}
│ │ │ │ ├── zzz{img_suffix}
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{seg_map_suffix}
│ │ │ │ ├── yyy{seg_map_suffix}
│ │ │ │ ├── zzz{seg_map_suffix}
│ │ │ ├── val
注意: 标注是跟图像同样的形状 (H, W),其中的像素值的范围是 [0, num_classes - 1]
。 您也可以使用 pillow 的 'P'
模式去创建包含颜色的标注。
数据整理好后,要创建几个文件:
1. 创建一个新文件 mmseg/datasets/example.py
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class ExampleDataset(BaseSegDataset):
METAINFO = dict(
classes=('xxx', 'xxx', ...),
palette=[[x, x, x], [x, x, x], ...])
def __init__(self, aeg1, arg2):
pass
新建自己的一个数据集类名,如ExampleDataset,继承自mmsegmentation的BaseSegDataSet。BaseSegDataSet是mmsegmentation的内置类,描述了数据的一些通用方法和属性。
2. 在 mmseg/datasets/__init__.py
中导入模块
# 顶端插入下列代码
from .example import ExampleDataset
# 后面加上 ExampleDataset
__all__ = [
..., 'ExampleDataset'
]
3. mmseg/utils/class_names.py
中补充数据集元信息
def example_classes():
return [
'xxx', 'xxx',
...
]
def example_palette():
return [
[x, x, x], [x, x, x],
...
]
dataset_aliases ={
'example': ['example', ...],
...
}
注意: 虽然我们这里定义了ExampleDataset,但后续运行训练代码时,可能出现没有登记数据集的报错,报错信息是:
KeyError: 'ExampleDataset is not in the dataset registry.
此时只需要在mmsegmentataion文件下运行一下下列命令就可解决
python setup.py install
4. 通过创建一个新的数据集配置文件 configs/_base_/datasets/example_dataset.py
来使用它
dataset_type = 'ExampleDataset'
data_root = 'data/example/'
...
上述几步骤的流程可以理解为:先定义一个ExampleDataset数据集的类,并添加好classes,palette等基本信息, config会通过example_dataset.py内的dataset_type参数来找到数据类的定义,然后实例化一个对象。
二. config文件构建
config文件是mmsegmentation内训练时的最重要文件,通过继承的方式一层层嵌套,定义了数据、网络、训练策略、默认设置(日志,可视化)四个部分。
1. 在configs/_base_/models/内定义了不同算法模型的网络文件,如pspnet_unet_s5-d16.py, 里面主要包括几个重要参数
- data_preprocessor:数据预处理字典
- model: 模型定义字典。包含了模型的norm_cfg、backbone、decode_head、auxiliary_head
- train_cfg 、test_cfg :模型训练和测试设置
2. 在configs/_base_/datasets里定义了一些数据集,如cityscapes.py,里面主要包括下列几个参数:
- dataset_type: 数据类型,会根据这个字符串映射到数据类
- data_root: 数据图像和标注文件的根路径
- crop_size: 图像输入网络的尺寸
- train_pipeline: 训练流程(加载图像,加载标注、图像增强)
- test_pipeline: 测试流程
- train_dataloader: 训练集加载器,内部包括batch_size, dataset, pipeline
- val_dataloader: 验证集加载器
- val_evaluator: 验证集计算方法:IOU或Dice
3. 在configs/_base_/schedules内定义了不同的训练策略,主要包括optimizer, max_iters等
4. 在configs/_base_/default_runtime.py 定义了一些通用信息,如环境配置,可视化配置,还有load_from(预训练模型路径)
在configs文件夹下,不同算法模型几乎都是经过2层或3层继承_base_里的响应模块之后,然后修改参数得到的。
同理,我们这里使用Unet自己写一个config文件my_unet.py, config顶层文件在修改继承来的参数时,只需要重写对应字典的键值就好,未修改的保留继承来自底层文件的值。
_base_ = [
'../_base_/models/pspnet_unet_s5-d16.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
# 1. 数据集设置
dataset_type = 'ExampleDataset'
data_root = 'data/crack220p_432x432/'
img_scale = (432, 432)
crop_size = (432, 432)
data_preprocessor = dict(size=crop_size)
# 2. 模型设置
norm_cfg = dict(type='BN', requires_grad=True) # 单GPU训练用BN, 多GPU训练用SyncBN
model = dict(
data_preprocessor = data_preprocessor,
backbone = dict(norm_cfg=norm_cfg),
decode_head = dict(num_classes=2, norm_cfg=norm_cfg), # 输出为2个类别
auxiliary_head = dict(num_classes=2, norm_cfg=norm_cfg),
test_cfg = dict(crop_size=crop_size, stride=(170, 170))
)
# 3. 训练流程
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=img_scale,
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
# 4. 测试流程
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=img_scale, keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# 5. tta流程,一般不用改
# 6. 数据加载器
train_dataloader = dict(
batch_size=2, # mmseg要去必须>=2
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/train', seg_map_path='ann_dir/train'),
pipeline=train_pipeline
))
val_dataloader = dict(
batch_size=1, # mmseg要求必须为1
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/val', seg_map_path='ann_dir/val'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
test_evaluator = val_evaluator
# 7.加载预训练模型
load_from = 'pretrain/pspnet_unet_s5-d16_256x256_40k_hrf_20201227_181818-fdb7e29b.pth'
# 8. 训练策略
train_cfg = dict(type='IterBasedTrainLoop', max_iters=10000, val_interval=1000)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1000))
三、训练
config文件配置好后,训练就很简单了,一行代码的事。
python tools/train.py ${配置文件} [可选参数]
-
--work-dir ${工作路径}
: 重新指定工作路径 -
--amp
: 使用自动混合精度计算 -
--resume
: 从工作路径中保存的最新检查点文件(checkpoint)恢复训练 -
--cfg-options ${需更覆盖的配置}
: 覆盖已载入的配置中的部分设置,并且 以 xxx=yyy 格式的键值对 将被合并到配置文件中。 比如: ‘–cfg-option model.encoder.in_channels=6’, 更多细节请看指导。
以上。