一直以来都在玩caffe,沉浸在写反传的快乐中(真的快乐么???),然后睁眼看世界发现pytorch已经火得不能再火了,学术界用了都说好,一直想要深入学一学,但是没啥契机和时间。最近复习检测相关的东西,发现mmdetection这个框架用起来很方便,新算法开源得也快,就尝试着读一读里面的代码,也算是借着它一起入门一下pytorch。
自己比较健忘,就在这边记录一下浏览过的代码好了(不会每一行都细读,以梳理为主)。
源码写得真香,路径在:open-mmlab/mmdetection
看代码之前想多说一句,以前看检测新出一个方法就刷下论文,没有成体系。但是现在这个框架将检测拆解模块化为backbone(躯干),neck(脖子), head(头部),无论是单阶段还是双阶段。所以现在看论文把已有方法想象成一个病人,把自己想象成一个执刀的外科医生,代入感很强,线索清晰,体系自成。
首先从tools里面的train.py入手。
这个main函数其实就做了四件事:
- 设定和读取各种配置;
- 创建模型;
- 创建数据集;
- 将模型,数据集和配置传进训练函数;
# 第一件事
args = parse_args()
cfg = Config.fromfile(args.config)
# 第二件事 创建模型
# 第一个参数cfg.model模型配置里面必须要有一个种类type,包括经典的算法如Faster RCNN, MaskRCNN等
# 其次,还包含几个部分,如backbone, neck, head
# backbone有深度,stage等信息,如resnet50对应着3,4,6,3四个重复stages
# neck一般FPN(feature pyramid network),需要指定num_outs几个输出之类的信息(之后会看到)
# head 就是具体到上层rpn_head, shared_head, bbox_head之类的
# 如果不清楚我们可以去某个config里面验证一下
# 返回的是一个类的对象,详见下面的build函数
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# 第三件事
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2: #是否添加验证集
datasets.append(build_dataset(cfg.data.val))
...
# 第四件事
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
logger=logger)
上面说道,build_detector返回的是一个type名字对应的类的对象,我们先来看一个config:
model = dict(
type='FasterRCNN', # 看到没,就是这个type!!!
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtract