系列文章目录
文章目录
- 系列文章目录
- 前言
- 一、mmsegmentation
- 1.[基础的环境配置](https://mmsegmentation.readthedocs.io/zh-cn/latest/migration/interface.html?highlight=default_scope#id8)
- 2.[可视化训练](https://mmsegmentation.readthedocs.io/zh-cn/latest/user_guides/visualization_feature_map.html?highlight=TensorboardVisBackend)-启用Tensorboard
- 3.日志
- 3.是否启用断点训练:应该只支持epoch训练模式
- 4.[测试增强,在测试阶段进行数据增强,比如多尺度数据增强(TEST TIME AUGMENTATION)](https://mmengine.readthedocs.io/zh-cn/latest/advanced_tutorials/test_time_augmentation.html?highlight=tta_model)
- 5.自定义数据类型,一般对应着改一下AD20k的数据读取就可以,修改类别与读取文件后缀
- 6. 数据预处理
- 7. 模型,根据segmentors文件夹下定义的语义分割(模型)框架,如编码解码结构
- 8.优化器
- 9.训练采用的参数设置,包括训练方式(迭代 or epoch)
- 10.默认采用的钩子函数,包括训练日志设置、参数保存设置、以及可视化相关设置等
- 11.训练与测试数据流:数据增强
- 12.训练数据加载器:train_dataloader
- 13.[精度评价需要计算的指标](https://mmsegmentation.readthedocs.io/zh-cn/latest/advanced_guides/evaluation.html?highlight=IoUMetric#ioumetric),可选项包括 ‘mIoU’、’mDice’ 和 ‘mFscore’。
前言
mmsegmentation实用参数配置文件,主要包括基础环境配置,学习率调整,优化器、数据流;对模型部分进行了删减与简化
一、mmsegmentation
1.基础的环境配置
#default_scope:搜索所有注册模块的起点
```cdefault_scope='mmseg'
#环境配置
#是否启用 cudnn_benchmark
env_cfg=dict(
cudnn_benchmark=True,
#设置多进程参数
mp_cfg=dict(mp_start_method='fork',opencv_num_threads=0),
# 设置分布式参数
dist_cfg=dict(backend='nccl'))
2.可视化训练-启用Tensorboard
#可视化训练还包括WANDB
visualizer=dict(
type='SegLocalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend'),
],
name='visualizer')
3.日志
log_processor=dict(by_epoch=False)
log_level='INFO'
3.是否启用断点训练:应该只支持epoch训练模式
load_from=None
#断点训练
resume=False
#是否开启混合精度训练
#amp = False
#工作路径
work_dir='D:\\Changdong\\mmsegmentation\\work_dir'
4.测试增强,在测试阶段进行数据增强,比如多尺度数据增强(TEST TIME AUGMENTATION)
tta_model=dict(type='SegTTAModel')
5.自定义数据类型,一般对应着改一下AD20k的数据读取就可以,修改类别与读取文件后缀
dataset_type='GLCDataset'
data_root='D:\\LanduseDataset'
crop_size=(
384,
384,
)
6. 数据预处理
data_preprocessor=dict(
type='SegDataPreProcessor',
mean=[
123.675,
116.28,
103.53,
103.53,
],
std=[58.395,57.12,57.375,57.375,
],
#多波段遥感图像读取不能进行彩色变换
bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255,
size=(384,384,))
7. 模型,根据segmentors文件夹下定义的语义分割(模型)框架,如编码解码结构
num_classes=4
model=dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
#孪生TF主干网络
#在主干网络中需要设置channel
backbone=dict(
#冻结训练设置:0-3可以进行冻结设置
frozen_stages=-1,
# 模型采用预训练权重进行初始化
#init_cfg=dict(
#type='Pretrained',
#也可以替换为自己的预训练数据
#checkpoint=
#'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth'
#),
#初始输入数据的纬度
in_channels=4),
#解码阶段:重点是修改对应的类别数以及损失函数
decode_head=dict(num_classes=4)
)
8.优化器
# 优化器设置,可以设置多个优化,mmlab支持设置多个优化器
optimizer=dict(
type='AdamW',
lr=0.0001,
weight_decay=0.05,
eps=1e-08,
betas=(
0.9,
0.999,
))
optim_wrapper=dict(
type='OptimWrapper',
optimizer=optimizer,
clip_grad=None,
paramwise_cfg=dict(norm_decay_mult=0.0))
#学习率设置,设置多个学习率,begin=0,end=50,设置学习率的应用范围
param_scheduler=[
#学习率线性预热
dict(type='LinearLR',start_factor=0.001,by_epoch=False,begin=0,end=50),
#'PolyLR'调整学习率
dict(type='PolyLR',eta_min=0.0001,power=0.9,begin=50,end=160000,by_epoch=False),]
9.训练采用的参数设置,包括训练方式(迭代 or epoch)
train_cfg=dict(
#验证、迭代设置
type='IterBasedTrainLoop',max_iters=160000,val_interval=2000)
val_cfg=dict(type='ValLoop')
test_cfg=dict(type='TestLoop')
10.默认采用的钩子函数,包括训练日志设置、参数保存设置、以及可视化相关设置等
#https://mmsegmentation.readthedocs.io/zh-cn/latest/advanced_guides/engine.html?highlight=IterBasedTrainLoop
default_hooks=dict(
#计算迭代时间
timer=dict(type='IterTimerHook'),
#打印日志间隔
logger=dict(type='LoggerHook',interval=50,log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
#权重文件保存
checkpoint=dict(type='CheckpointHook',by_epoch=False,interval=2000,save_best='mIoU',max_keep_ckpts=4,save_last=True),
sampler_seed=dict(type='DistSamplerSeedHook'),
#可视化验证集与测试集结果;draw=True
visualization=dict(type='SegVisualizationHook',draw=True,interval=2000))
auto_scale_lr=dict(enable=False,base_batch_size=16)
11.训练与测试数据流:数据增强
train_pipeline=[
dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
dict(type='LoadAnnotations',reduce_zero_label=False),
dict(type='Resize',scale=(512,384),ratio_range=(0.5,2.0,),keep_ratio=True),
dict(type='RandomCrop',crop_size=(
384,
384,
),cat_max_ratio=0.75),
dict(type='RandomFlip',prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs'),
]
# 测试、验证阶段数据流
test_pipeline=[
dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
dict(type='Resize',scale=(
384,
384,
),keep_ratio=True),
dict(type='LoadAnnotations',reduce_zero_label=False),
dict(type='PackSegInputs'),
]
#tta_model数据流
tta_pipeline=[
dict(type='LoadImageFromFile',backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize',scale_factor=0.5,keep_ratio=True),
dict(type='Resize',scale_factor=0.75,keep_ratio=True),
dict(type='Resize',scale_factor=1.0,keep_ratio=True),
dict(type='Resize',scale_factor=1.25,keep_ratio=True),
dict(type='Resize',scale_factor=1.5,keep_ratio=True),
dict(type='Resize',scale_factor=1.75,keep_ratio=True),
],
[
dict(type='RandomFlip',prob=0.0,direction='horizontal'),
dict(type='RandomFlip',prob=1.0,direction='horizontal'),
],
[
dict(type='LoadAnnotations'),
],
[
dict(type='PackSegInputs'),
],
]),
]
12.训练数据加载器:train_dataloader
#这里使用了数据拼接'ConcatDataset',将不同文件夹下的数据进行整合,用于模型训练
train_dataloader=dict(
batch_size=8,
num_workers=4,
#加速训练设置
persistent_workers=True,
#数据采样方式,是否进行数据随机打乱
sampler=dict(type='InfiniteSampler',shuffle=True),
#数据拼接
dataset=dict(
#数据集类型,这里定义了拼接类型
type='ConcatDataset',
datasets=[
#数据集1
dict(
type='GLCDataset',
data_prefix=dict(
img_path='D:\\LanduseDataset\\train33\\train\\images',
seg_map_path='D:\\LanduseDataset\\train33\\train\\labels'),
pipeline=[
dict(
type='LoadImageFromFile',imdecode_backend='tifffile'),
# 背景以0作为样本
dict(type='LoadAnnotations',reduce_zero_label=False),
dict(type='Resize',scale=(512,384,),keep_ratio=True),
dict(type='RandomCrop',crop_size=(384,384,),cat_max_ratio=0.75),
dict(type='RandomFlip',prob=0.5),
dict(type='PackSegInputs'),]),
#数据集2
dict(
type='GLCDataset',
data_prefix=dict(
img_path='D:\\LanduseDataset\\train11\\train\\images',
seg_map_path='D:\\LanduseDataset\\train11\\train\\labels'),
pipeline=[
dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
dict(type='LoadAnnotations',reduce_zero_label=False),
dict(type='Resize',scale=(512,384,),keep_ratio=True),
dict(type='RandomCrop',crop_size=(384,384,),cat_max_ratio=0.75),
dict(type='RandomFlip',prob=0.5),
dict(type='PackSegInputs'),
]),
]))
#验证集数据加载器
val_dataloader=dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler',shuffle=False),
dataset=dict(
type='GLCDataset',
#data_root='',
data_prefix=dict(
img_path='D:\\LanduseDataset\\train11\\val\\images',
seg_map_path='D:\\LanduseDataset\\train11\\val\\labels'),
pipeline=[
dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
dict(type='Resize',scale=(
384,
384,
),keep_ratio=True),
dict(type='LoadAnnotations',reduce_zero_label=False),
dict(type='PackSegInputs'),
]))
13.精度评价需要计算的指标,可选项包括 ‘mIoU’、’mDice’ 和 ‘mFscore’。
val_evaluator=dict(
type='IoUMetric',iou_metrics=[
'mIoU',
])
test_evaluator=dict(
type='IoUMetric',iou_metrics=[
'mIoU',
])