mmsegmentation参数配置实用详解

系列文章目录



前言

mmsegmentation实用参数配置文件,主要包括基础环境配置,学习率调整,优化器、数据流;对模型部分进行了删减与简化


一、mmsegmentation

1.基础的环境配置


#default_scope:搜索所有注册模块的起点
```cdefault_scope='mmseg'
#环境配置
#是否启用 cudnn_benchmark
env_cfg=dict(
cudnn_benchmark=True,
#设置多进程参数
mp_cfg=dict(mp_start_method='fork',opencv_num_threads=0),
# 设置分布式参数
dist_cfg=dict(backend='nccl'))

2.可视化训练-启用Tensorboard

#可视化训练还包括WANDB

visualizer=dict(
type='SegLocalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend'),
],
name='visualizer')

3.日志

log_processor=dict(by_epoch=False)
log_level='INFO'

3.是否启用断点训练:应该只支持epoch训练模式

load_from=None
#断点训练
resume=False
#是否开启混合精度训练
#amp = False
#工作路径
work_dir='D:\\Changdong\\mmsegmentation\\work_dir'

4.测试增强,在测试阶段进行数据增强,比如多尺度数据增强(TEST TIME AUGMENTATION)

tta_model=dict(type='SegTTAModel')

5.自定义数据类型,一般对应着改一下AD20k的数据读取就可以,修改类别与读取文件后缀

dataset_type='GLCDataset'
data_root='D:\\LanduseDataset'
crop_size=(
384,
384,
)

6. 数据预处理

data_preprocessor=dict(
type='SegDataPreProcessor',
mean=[
123.675,
116.28,
103.53,
103.53,
],
std=[58.395,57.12,57.375,57.375,
],
#多波段遥感图像读取不能进行彩色变换
bgr_to_rgb=False,
pad_val=0,
seg_pad_val=255,
size=(384,384,))

7. 模型,根据segmentors文件夹下定义的语义分割(模型)框架,如编码解码结构

num_classes=4
model=dict(
	type='EncoderDecoder',
	data_preprocessor=data_preprocessor,
	#孪生TF主干网络
	#在主干网络中需要设置channel
	backbone=dict(
	#冻结训练设置:0-3可以进行冻结设置
	frozen_stages=-1,
	# 模型采用预训练权重进行初始化
	#init_cfg=dict(
	#type='Pretrained',
	#也可以替换为自己的预训练数据
	#checkpoint=
#'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth'
	#),
	#初始输入数据的纬度
	in_channels=4),
	#解码阶段:重点是修改对应的类别数以及损失函数
	decode_head=dict(num_classes=4)
	)

8.优化器

# 优化器设置,可以设置多个优化,mmlab支持设置多个优化器
optimizer=dict(
	type='AdamW',
	lr=0.0001,
	weight_decay=0.05,
	eps=1e-08,
	betas=(
	0.9,
	0.999,
	))
optim_wrapper=dict(
type='OptimWrapper',
optimizer=optimizer,
clip_grad=None,
paramwise_cfg=dict(norm_decay_mult=0.0))
#学习率设置,设置多个学习率,begin=0,end=50,设置学习率的应用范围
param_scheduler=[
#学习率线性预热
dict(type='LinearLR',start_factor=0.001,by_epoch=False,begin=0,end=50),
#'PolyLR'调整学习率
dict(type='PolyLR',eta_min=0.0001,power=0.9,begin=50,end=160000,by_epoch=False),]

9.训练采用的参数设置,包括训练方式(迭代 or epoch)

train_cfg=dict(
#验证、迭代设置
type='IterBasedTrainLoop',max_iters=160000,val_interval=2000)
val_cfg=dict(type='ValLoop')
test_cfg=dict(type='TestLoop')

10.默认采用的钩子函数,包括训练日志设置、参数保存设置、以及可视化相关设置等

#https://mmsegmentation.readthedocs.io/zh-cn/latest/advanced_guides/engine.html?highlight=IterBasedTrainLoop

default_hooks=dict(
#计算迭代时间
timer=dict(type='IterTimerHook'),
#打印日志间隔
logger=dict(type='LoggerHook',interval=50,log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
#权重文件保存
checkpoint=dict(type='CheckpointHook',by_epoch=False,interval=2000,save_best='mIoU',max_keep_ckpts=4,save_last=True),
sampler_seed=dict(type='DistSamplerSeedHook'),
#可视化验证集与测试集结果;draw=True
visualization=dict(type='SegVisualizationHook',draw=True,interval=2000))
auto_scale_lr=dict(enable=False,base_batch_size=16)

11.训练与测试数据流:数据增强

train_pipeline=[
	dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
	dict(type='LoadAnnotations',reduce_zero_label=False),
	dict(type='Resize',scale=(512,384),ratio_range=(0.5,2.0,),keep_ratio=True),
	dict(type='RandomCrop',crop_size=(
	384,
	384,
	),cat_max_ratio=0.75),
	dict(type='RandomFlip',prob=0.5),
	dict(type='PhotoMetricDistortion'),
	dict(type='PackSegInputs'),
]

# 测试、验证阶段数据流
test_pipeline=[
	dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
	dict(type='Resize',scale=(
	384,
	384,
	),keep_ratio=True),
	dict(type='LoadAnnotations',reduce_zero_label=False),
	dict(type='PackSegInputs'),
]
#tta_model数据流
tta_pipeline=[
	dict(type='LoadImageFromFile',backend_args=None),
	dict(
	type='TestTimeAug',
	transforms=[
	[
	dict(type='Resize',scale_factor=0.5,keep_ratio=True),
	dict(type='Resize',scale_factor=0.75,keep_ratio=True),
	dict(type='Resize',scale_factor=1.0,keep_ratio=True),
	dict(type='Resize',scale_factor=1.25,keep_ratio=True),
	dict(type='Resize',scale_factor=1.5,keep_ratio=True),
	dict(type='Resize',scale_factor=1.75,keep_ratio=True),
	],
	[
	dict(type='RandomFlip',prob=0.0,direction='horizontal'),
	dict(type='RandomFlip',prob=1.0,direction='horizontal'),
	],
	[
	dict(type='LoadAnnotations'),
	],
	[
	dict(type='PackSegInputs'),
	],
	]),
]

12.训练数据加载器:train_dataloader

#这里使用了数据拼接'ConcatDataset',将不同文件夹下的数据进行整合,用于模型训练
train_dataloader=dict(
batch_size=8,
num_workers=4,
#加速训练设置
persistent_workers=True,
#数据采样方式,是否进行数据随机打乱
sampler=dict(type='InfiniteSampler',shuffle=True),
#数据拼接
dataset=dict(
	#数据集类型,这里定义了拼接类型
	type='ConcatDataset',
	datasets=[
	#数据集1
	dict(
	type='GLCDataset',
	data_prefix=dict(
	img_path='D:\\LanduseDataset\\train33\\train\\images',
	seg_map_path='D:\\LanduseDataset\\train33\\train\\labels'),
	pipeline=[
	dict(
		type='LoadImageFromFile',imdecode_backend='tifffile'),
		# 背景以0作为样本
		dict(type='LoadAnnotations',reduce_zero_label=False),
		dict(type='Resize',scale=(512,384,),keep_ratio=True),
		dict(type='RandomCrop',crop_size=(384,384,),cat_max_ratio=0.75),
		dict(type='RandomFlip',prob=0.5),
		dict(type='PackSegInputs'),]),
#数据集2
	dict(
		type='GLCDataset',
		data_prefix=dict(
		img_path='D:\\LanduseDataset\\train11\\train\\images',
		seg_map_path='D:\\LanduseDataset\\train11\\train\\labels'),
		pipeline=[
		dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
		dict(type='LoadAnnotations',reduce_zero_label=False),
		dict(type='Resize',scale=(512,384,),keep_ratio=True),
		dict(type='RandomCrop',crop_size=(384,384,),cat_max_ratio=0.75),
		dict(type='RandomFlip',prob=0.5),
		dict(type='PackSegInputs'),
		]),
	]))
#验证集数据加载器
val_dataloader=dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler',shuffle=False),
dataset=dict(
type='GLCDataset',
#data_root='',
data_prefix=dict(
		img_path='D:\\LanduseDataset\\train11\\val\\images',
		seg_map_path='D:\\LanduseDataset\\train11\\val\\labels'),
		pipeline=[
		dict(type='LoadImageFromFile',imdecode_backend='tifffile'),
		dict(type='Resize',scale=(
		384,
		384,
		),keep_ratio=True),
		
		dict(type='LoadAnnotations',reduce_zero_label=False),
		dict(type='PackSegInputs'),
		]))

13.精度评价需要计算的指标,可选项包括 ‘mIoU’、’mDice’ 和 ‘mFscore’。

val_evaluator=dict(
type='IoUMetric',iou_metrics=[
'mIoU',
])
test_evaluator=dict(
type='IoUMetric',iou_metrics=[
'mIoU',
])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值