模型配置
原mmdetection中的模型配置
# 模型的配置文件
model = dict(
type='RPN', # 模型类型
# 数据预处理器的类型和参数设置
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],# 均值
std=[58.395, 57.12, 57.375],# 标准差
bgr_to_rgb=True, # 是否将图像由BGR格式转为RGB格式
pad_size_divisor=32),# Pad的大小除数
# 主干网络的类型和参数设置
backbone=dict(
type='ResNet',# 使用ResNet-50
depth=50,
num_stages=4, # 层数
out_indices=(0, 1, 2, 3), # 输出特征图的层数
frozen_stages=1, # 冻结的层数
norm_cfg=dict(type='BN', requires_grad=True), # 归一化类型
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
# RPN网络中neck层的类型和参数设置
neck=dict(
type='FPN',# 使用的是FPN
in_channels=[256, 512, 1024, 2048],# 输入特征图的通道数
out_channels=256,# 输出特征图的通道数
num_outs=5),# 数量
# RPN网络中head层的类型和参数设置
rpn_head=dict(
type='RPNHead',
in_channels=256,# 输出特征图的通道数
feat_channels=256, # 输出特征图的通道数
# anchor生成器
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
# bbox编码器
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
# 分类损失
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
# 回归损失
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
如果想要用
FCOS
去训练,则可以设置如下过程
model = dict(
type='FCOS',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
bbox_head=dict(
type='FCOSHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512), (512, -1)),
center_sampling=True,
center_sample_radius=1.5,
norm_on_bbox=True,
centerness_on_reg=True,
dcn_on_last_conv=False,
conv_bias=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
数据集和评测器配置
在使用
执行器
进行训练、测试、验证时,我们需要配置
Dataloader
。构建数据
dataloader
需要设置数
据集(
dataset
)和数据处理流程(
data pipeline
)。 由于这部分的配置较为复杂,我们使用中间变量
来简化
dataloader
配置的编写。
训练数据流程
测试数据处理流程
训练数据配置
![](https://i-blog.csdnimg.cn/blog_migrate/f4a89ec0ec22d828ed8c10bb6b3ea469.png)
测试数据配置
![](https://i-blog.csdnimg.cn/blog_migrate/cdc0f4e5045d2eaa13e69cbc43c8d1fe.png)
评测器
评测器
用于计算训练模型在验证和测试数据集上的指标。评测器的配置由一个或一组评价指标
(
Metric
)配置组成:
![](https://i-blog.csdnimg.cn/blog_migrate/e7b3afa2c43a096afd655903b2c61c19.png)
训练和测试的配置
MMEngine
的
Runner
使用
Loop
来控制训练,验证和测试过程。 用户可以使用这些字段设置最大训练轮次和验证间隔。
![](https://i-blog.csdnimg.cn/blog_migrate/8a66ca017d439a0421ed9daadce43e01.png)
优化相关配置
optim_wrapper
是配置优化相关设置的字段。优化器封装(
OptimWrapper
)不仅提供了优化器的功能,还支持梯度裁剪、混合精度训练等功能。更多内容请看
优化器封装教程
。
![](https://i-blog.csdnimg.cn/blog_migrate/fd3e7349ae2e27348506b0d4658e18b0.png)
param_scheduler
字段用于配置参数调度器(
Parameter Scheduler
)来调整优化器的超参数(例如学习率和动量)。 用户可以组合多个调度器来创建所需的参数调整策略。 在
参数调度器教程
和
参数调度器 API
文档
中查找更多信息
![](https://i-blog.csdnimg.cn/blog_migrate/4791ecad07edf2cc97f091ec54a6e56c.png)
钩子配置
用户可以在训练、验证和测试循环上添加钩子,以便在运行期间插入一些操作。配置中有两种不同的钩子字段,一种是 default_hooks
,另一种是
custom_hooks
。
default_hooks
是一个字典,用于配置运行时必须使用的钩子。这些钩子具有默认优先级,如果未设置,runner
将使用默认值。如果要禁用默认钩子,用户可以将其配置设置为
None
。更多内容请看
钩子教程
。
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='DetVisualizationHook'))
运行相关配置
default_scope = 'mmdet' # 默认的注册器域名,默认从此注册器域中寻找模块。请参考
https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/registry.html
env_cfg = dict(
cudnn_benchmark=False, # 是否启用 cudnn benchmark
mp_cfg=dict( # 多进程设置
mp_start_method='fork', # 使用 fork 来启动多进程。'fork' 通常比 'spawn' 更
快,但可能存在隐患。请参考 https://github.com/pytorch/pytorch/issues/1355
opencv_num_threads=0), # 关闭 opencv 的多线程以避免系统超负荷
dist_cfg=dict(backend='nccl'), # 分布式相关设置
)
vis_backends = [dict(type='LocalVisBackend')] # 可视化后端,请参考
https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.ht
ml
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(
type='LogProcessor', # 日志处理器用于处理运行时日志
window_size=50, # 日志数值的平滑窗口
by_epoch=True) # 是否使用 epoch 格式的日志。需要与训练循环的类型保存一致。
log_level = 'INFO' # 日志等级
load_from = None # 从给定路径加载模型检查点作为预训练模型。这不会恢复训练。
resume = False # 是否从 `load_from` 中定义的检查点恢复。 如果 `load_from` 为 None,它
将恢复 `work_dir` 中的最新检查点。