怎么看mmdetection版本_入门mmdetection(壹)

  一直以来都在玩caffe,沉浸在写反传的快乐中(真的快乐么???),然后睁眼看世界发现pytorch已经火得不能再火了,学术界用了都说好,一直想要深入学一学,但是没啥契机和时间。最近复习检测相关的东西,发现mmdetection这个框架用起来很方便,新算法开源得也快,就尝试着读一读里面的代码,也算是借着它一起入门一下pytorch。

  自己比较健忘,就在这边记录一下浏览过的代码好了(不会每一行都细读,以梳理为主)。

  源码写得真香,路径在:open-mmlab/mmdetection

  看代码之前想多说一句,以前看检测新出一个方法就刷下论文,没有成体系。但是现在这个框架将检测拆解模块化为backbone(躯干),neck(脖子), head(头部),无论是单阶段还是双阶段。所以现在看论文把已有方法想象成一个病人,把自己想象成一个执刀的外科医生,代入感很强,线索清晰,体系自成。

  首先从tools里面的train.py入手。

  这个main函数其实就做了四件事:

  1. 设定和读取各种配置;
  2. 创建模型;
  3. 创建数据集;
  4. 将模型,数据集和配置传进训练函数;
# 第一件事
args = parse_args()
cfg = Config.fromfile(args.config)

# 第二件事 创建模型
# 第一个参数cfg.model模型配置里面必须要有一个种类type,包括经典的算法如Faster RCNN, MaskRCNN等
# 其次,还包含几个部分,如backbone, neck, head
# backbone有深度,stage等信息,如resnet50对应着3,4,6,3四个重复stages
# neck一般FPN(feature pyramid network),需要指定num_outs几个输出之类的信息(之后会看到)
# head 就是具体到上层rpn_head, shared_head, bbox_head之类的
# 如果不清楚我们可以去某个config里面验证一下
# 返回的是一个类的对象,详见下面的build函数
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

# 第三件事
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:  #是否添加验证集
    datasets.append(build_dataset(cfg.data.val))
...

# 第四件事
train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

上面说道,build_detector返回的是一个type名字对应的类的对象,我们先来看一个config:

model = dict(
    type='FasterRCNN',  # 看到没,就是这个type!!!
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtract
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值