mmsegmentationV1.0训练分割模型(自定义数据格式和config)

MMSegmentation 保持了 MM 系列一贯的风格,拥有灵活的模块化设计和全面的高性能model zoo。目前我们支持非常多的主流backbone和语义分割算法,支持多种数据集如 Cityscapes,ADE20K,Pascal VOC 2012上的训练结果(目前应该是语义分割中最大的 模型库)。

一. 数据准备

这个根据自己的数据而定,在mmsegmentataion里,任意算法几乎都提供了不同的数据集训练的预训练模型,有两种准备数据的方式。

1. 把自己的数据修改成匹配到对应数据集格式。如cityscope格式。

2. 自定义自己的数据格式。

本文以自定义自己的数据格式为例,记录mmsegmentataion的训练过程。

参考:新增自定义数据集 — MMSegmentation 1.1.0 文档

把自己的数据集做成下列格式:

├── data
│   ├── my_dataset
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{img_suffix}
│   │   │   │   ├── yyy{img_suffix}
│   │   │   │   ├── zzz{img_suffix}
│   │   │   ├── val
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{seg_map_suffix}
│   │   │   │   ├── yyy{seg_map_suffix}
│   │   │   │   ├── zzz{seg_map_suffix}
│   │   │   ├── val
注意: 标注是跟图像同样的形状 (H, W),其中的像素值的范围是 [0, num_classes - 1]。 您也可以使用 pillow 的 'P' 模式去创建包含颜色的标注。

数据整理好后,要创建几个文件:

1. 创建一个新文件 mmseg/datasets/example.py

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class ExampleDataset(BaseSegDataset):

    METAINFO = dict(
        classes=('xxx', 'xxx', ...),
        palette=[[x, x, x], [x, x, x], ...])

    def __init__(self, aeg1, arg2):
        pass

新建自己的一个数据集类名,如ExampleDataset,继承自mmsegmentation的BaseSegDataSet。BaseSegDataSet是mmsegmentation的内置类,描述了数据的一些通用方法和属性。

2. 在 mmseg/datasets/__init__.py 中导入模块

# 顶端插入下列代码
from .example import ExampleDataset


# 后面加上 ExampleDataset
__all__ = [
    ..., 'ExampleDataset'
]

3. mmseg/utils/class_names.py 中补充数据集元信息

def example_classes():
    return [
        'xxx', 'xxx',
        ...
    ]

def example_palette():
    return [
        [x, x, x], [x, x, x],
        ...
    ]
dataset_aliases ={
    'example': ['example', ...],
    ...
}

 注意: 虽然我们这里定义了ExampleDataset,但后续运行训练代码时,可能出现没有登记数据集的报错,报错信息是:

KeyError: 'ExampleDataset is not in the dataset registry.

此时只需要在mmsegmentataion文件下运行一下下列命令就可解决

python setup.py install

4. 通过创建一个新的数据集配置文件 configs/_base_/datasets/example_dataset.py 来使用它

dataset_type = 'ExampleDataset'
data_root = 'data/example/'
...

上述几步骤的流程可以理解为:先定义一个ExampleDataset数据集的类,并添加好classes,palette等基本信息, config会通过example_dataset.py内的dataset_type参数来找到数据类的定义,然后实例化一个对象。

二. config文件构建

config文件是mmsegmentation内训练时的最重要文件,通过继承的方式一层层嵌套,定义了数据、网络、训练策略、默认设置(日志,可视化)四个部分。

1. 在configs/_base_/models/内定义了不同算法模型的网络文件,如pspnet_unet_s5-d16.py, 里面主要包括几个重要参数

  • data_preprocessor:数据预处理字典
  • model: 模型定义字典。包含了模型的norm_cfg、backbone、decode_head、auxiliary_head
  • train_cfg  、test_cfg :模型训练和测试设置

2. 在configs/_base_/datasets里定义了一些数据集,如cityscapes.py,里面主要包括下列几个参数:

  • dataset_type: 数据类型,会根据这个字符串映射到数据类
  • data_root: 数据图像和标注文件的根路径
  • crop_size: 图像输入网络的尺寸
  • train_pipeline: 训练流程(加载图像,加载标注、图像增强)
  • test_pipeline: 测试流程
  • train_dataloader: 训练集加载器,内部包括batch_size, dataset, pipeline
  • val_dataloader: 验证集加载器
  • val_evaluator: 验证集计算方法:IOU或Dice

3. 在configs/_base_/schedules内定义了不同的训练策略,主要包括optimizer, max_iters等

4. 在configs/_base_/default_runtime.py 定义了一些通用信息,如环境配置,可视化配置,还有load_from(预训练模型路径)

在configs文件夹下,不同算法模型几乎都是经过2层或3层继承_base_里的响应模块之后,然后修改参数得到的。

同理,我们这里使用Unet自己写一个config文件my_unet.py, config顶层文件在修改继承来的参数时,只需要重写对应字典的键值就好,未修改的保留继承来自底层文件的值。

_base_ = [
    '../_base_/models/pspnet_unet_s5-d16.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]

# 1. 数据集设置
dataset_type = 'ExampleDataset'
data_root = 'data/crack220p_432x432/'
img_scale = (432, 432)
crop_size = (432, 432)
data_preprocessor = dict(size=crop_size)


# 2. 模型设置
norm_cfg = dict(type='BN', requires_grad=True) # 单GPU训练用BN, 多GPU训练用SyncBN
model = dict(
    data_preprocessor = data_preprocessor,
    backbone = dict(norm_cfg=norm_cfg),
    decode_head = dict(num_classes=2, norm_cfg=norm_cfg), # 输出为2个类别
    auxiliary_head = dict(num_classes=2, norm_cfg=norm_cfg),
    test_cfg = dict(crop_size=crop_size, stride=(170, 170)) 
    )


# 3. 训练流程
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=img_scale,
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]


# 4. 测试流程
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# 5. tta流程,一般不用改
# 6. 数据加载器
train_dataloader = dict(
    batch_size=2,   # mmseg要去必须>=2
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline
        ))

val_dataloader = dict(
    batch_size=1,  # mmseg要求必须为1
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
test_evaluator = val_evaluator

# 7.加载预训练模型
load_from = 'pretrain/pspnet_unet_s5-d16_256x256_40k_hrf_20201227_181818-fdb7e29b.pth'

# 8. 训练策略
train_cfg = dict(type='IterBasedTrainLoop', max_iters=10000, val_interval=1000)
default_hooks = dict(
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1000))

三、训练

config文件配置好后,训练就很简单了,一行代码的事。

python tools/train.py  ${配置文件} [可选参数]
  • --work-dir ${工作路径}: 重新指定工作路径

  • --amp: 使用自动混合精度计算

  • --resume: 从工作路径中保存的最新检查点文件(checkpoint)恢复训练

  • --cfg-options ${需更覆盖的配置}: 覆盖已载入的配置中的部分设置,并且 以 xxx=yyy 格式的键值对 将被合并到配置文件中。 比如: ‘–cfg-option model.encoder.in_channels=6’, 更多细节请看指导。

以上。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值