Detectron 之 config.py 文件参数 快速跑自己模型 Mmdetection

在这里插入图片描述

Detectron:

Detectron
1、参数理解
中文参数说明
官方参数说明
多GPU配置

2、格式转换
训练集格式转换
coco格式
roidb格式

3、训练流程
训练流程1
训练流程2

4、其他
添加垂直翻转
faster rcnn流程:附带部分参数解释
Faster-RCNN算法精读:最后一段对理解参数有帮助

Mmdetection:

mmdetection采用注册所有模块,然后使用配置文件构建模块的方式,因此查找整个流程时不易。
mmdetection数据增强是一次一张图片
Mmdetection
mmdetection流程和detectron相似
参数解释
测试时对增强图片的逆处理流程:
1、增强transform位置
2、构建数据集BuildDataset,应用测试增强
3、构建数据集加载器BuildDataLoader
4、模型测试流程:是走simple test还是aug test
5、具体到某个模型Cascade Rcnn
6、simple的逆处理:bbox rescale
7、验证时调用coco接口计算指标

配置详解:
基础参数就不做介绍了

# model settings
model = dict(
	# 模块路径:mmdet/models/detectors
    type='CascadeRCNN',
    num_stages=3,
    pretrained='open-mmlab://resnext101_32x4d',
    backbone=dict(
    	# 模块路径: mmdet/models/backbones
        type='ResNeXt',
        depth=101,
        groups=32,
        base_width=4,
        num_stages=4,
        # 除了4个大的stage,底层还有个7*7卷积层
        out_indices=(0, 1, 2, 3),
        # frozen 7*7卷积层 + out_indices为1-1=0的stage
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
    	# 模块路径:mmdet/models/necks
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
    	# 模块路径:mmdet/models/anchor_heads
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],    # 视数据集目标面积而定
        anchor_ratios=[0.5, 1.0, 2.0],   # 可根据数据集聚类统计
        anchor_strides=[4, 8, 16, 32, 64],  # 框大小 : (4*8) ** 2
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
        	# rpn分类只做前背景0-1分类
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),  
            # rpn回归只做目标回归,背景0损失 
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
    	# 模块路径: mmdet/models/roi_extractors
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=[
    	# 模块路径 : mmdet/models/bbox_heads
        dict(
            type='SharedFCBBoxHead',
            num_fcs=2,
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=81,   # 类别数+1(背景),代码会将anno中的categories按顺序进行映射,因此预测时可能需要反向映射
            target_means=[0., 0., 0., 0.],
            target_stds=[0.1, 0.1, 0.2, 0.2],
            reg_class_agnostic=True,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    ],
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
        	# 模块路径:mmdet/core/bbox/assigners
        	# 注意该模块是在rpn_head.loss里面调用的mmdet/core/anchor/anchor_target.py实现的
        	# rpn的loss训练的bbox没有使用NMS
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,     # 可细致调参
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
        	# 模块路径:mmdet/core/bbox/samplers
        	# 首先预采集256*0.5(不足时就低于)的正样本,然后256-已采集正样本数量得到负样本采集数量进行采集
        	# 对于不足0.5时,可以使用参数neg_pos_ub降低负样本采集数量
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        # 可为前背景样本loss赋予权重
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
    	# 这一步是一阶段到二阶段的过渡,使用NMS提取兴趣区域
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,    # 可尝试调参
        min_bbox_size=0),
    rcnn=[
        dict(
            assigner=dict(
            	# 该模块在detectors里面调用
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,   # 可细致调参
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                ignore_iof_thr=-1),
            sampler=dict(
            	# 该模块在detectors里面调用
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,    # 可在此处为样本loss加权
            debug=False),
    ],
    stage_loss_weights=[1, 0.5, 0.25])
test_cfg = dict(
    rpn=dict(
    	# 测试时的兴趣区域提取规则
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,  # 可尝试调参
        min_bbox_size=0),
    rcnn=dict(
    	# 测试时对最终预测的bbos提取规则
    	# bbox分类得分阈值,可细致调参
        score_thr=0.05,
        # NMS的ios阈值,可细致调参
        nms=dict(type='nms', iou_thr=0.5),
        # 单张图片最大输出目标数量
        max_per_img=100))

其他路径:
验证时加入的hook
数据增强
控制样本的哪些信息传入模型
anns中categories的顺序与label位置对应

目标检测模块结构

Object Detection and Classification using R-CNNs

比赛技巧

Bag of Freebies for Training Object Detection Neural Networks
目标检测精度提升之奇技淫巧
目标检测调优技巧:《Bag of Freebies for Training Object Detection Neural Networks》论文笔记
我这两年的目标检测
目标检测Tricks
工业视觉中的目标检测——兼谈天池大赛优胜方案
Crowdhuman人体检测比赛第一名经验总结
津南数字制造算法挑战赛【赛场二】总决赛 亚军比赛攻略_这是我的马队
目标检测系列三:奇技淫巧

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`config.py` 文件一般用于存放程序的配置信息,它是一个 Python 模块,可以在其它模块中导入并使用。在 `config.py` 文件中,我们可以定义一些常量、变量、函数等,用于存放程序运行所需的配置信息,比如数据库连接信息、API 接口地址、文件路径等等。通过将这些配置信息单独存放在一个文件中,我们可以方便地对它们进行修改和管理,从而提高代码的可维护性和可扩展性。 在 `config.py` 文件中,通常会定义一个 `get_config()` 函数,用于读取配置信息并返回一个字典或对象,以便其它模块可以方便地调用这些配置信息。同时,我们也可以在 `config.py` 文件中添加一些注释,以便于理解和维护。例如: ```python # 数据库连接信息 DB_HOST = 'localhost' DB_PORT = 3306 DB_USER = 'root' DB_PASSWD = '123456' DB_NAME = 'mydb' # 文件路径 LOG_PATH = '/var/log/myapp.log' CONFIG_PATH = '/etc/myapp/config.ini' def get_config(): # 读取配置信息并返回一个字典 return { 'db_host': DB_HOST, 'db_port': DB_PORT, 'db_user': DB_USER, 'db_passwd': DB_PASSWD, 'db_name': DB_NAME, 'log_path': LOG_PATH, 'config_path': CONFIG_PATH, } ``` 这样,在其它模块中,我们就可以通过 `get_config()` 函数来获取这些配置信息。例如: ```python # 导入 config 模块 from config import get_config # 获取配置信息 config = get_config() # 打印数据库连接信息 print(f"database host: {config['db_host']}") print(f"database port: {config['db_port']}") print(f"database user: {config['db_user']}") print(f"database passwd: {config['db_passwd']}") print(f"database name: {config['db_name']}") # 打印文件路径 print(f"log path: {config['log_path']}") print(f"config path: {config['config_path']}") ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值