引言:
cascade_rcnn.py
文件在moels/detections
文件夹下。本次对文件cascade_rcnn.py
的代码解读,是根据py配置文件configs/cascade_rcnn_r50_fpn_1x.py
的数据信息进行讲解的。
moels/detectionscascade_rcnn.py文件中
主要的内容如下:
- __init__() :module的构造函数。
- init_weights() :backbone为cascade rcnn的初始化权重方法,在__init__()调用进行初始化。
- extract_feat() :提取img特征,主要实现了backbone和neck的forward()的前向计算。
- forward_train() :在这里实现层之间的连接关系,其实就是所谓的前向传播。当执行model(x)(该model为module的子类)的时候,底层自动调用forward方法计算结果。
- simple_test() :检测过程的前向传播forward调用的函数,通过最原始的nn.Module到父类BaseDetector的forward,继续由底层 ,层层向上调用到这里。
- aug_test() :Test with augmentations。
- show_result() :
共七个部分,本篇文章主要对前四个部分的代码精度,这四个步骤中的__init__()
和forward_train()
是module类的最主要的两个部分,也是定义网络的最关键的部分。
自定义一个模型就是通过继承nn.Module类来实现,在__init__()
构造函数中申明各个层的定义,在forward()
中实现层之间的连接关系,实际上就是前向传播的过程。
注:后面三个部分,博主后续会继续阅读代码,在对这三个部分进行补充。
首先,看本篇讲解时,先了解一下下篇文章,该文章讲解了创建模型的过程,尤其以detection为例,讲解了mmdetection通过注册表的形式,实例化了类名为DETECTION的Rigistry类,并且在其module_dict属性中,保存了detection的module类,和其对应的类名。通过这篇文章,可以了解mmdetection如何注册和创建模型的。
其次,了解一下torch.nn.module
(有pytorch基础也行,博主刚开始看mmdetection时,没有pytorch一点基础,然后看到forward()函数时,找了好几个文件夹,看他在哪里调用的…,后面才知道,forward()是自定义层的前向计算,自动执行的( 也就是对输入自动进行处理)),推荐下篇文章:
__init__()
@DETECTORS.register_module
#在build_from_cfg()中,实例化detector,然后在通过形参的方式,将类和类名送入了方法register_module中。
class CascadeRCNN(BaseDetector, RPNTestMixin):
# 参数来自cascade_rcnn_r50_fpn_1x.py
def __init__(self,
num_stages, # 3
backbone, # ResNet
neck=None, # FPN
shared_head=None,
rpn_head=None, # RPNHead
bbox_roi_extractor=None, # SingleRoIExtractor
bbox_head=None, # SharedFCBBoxHead * 3 (三阶段)
mask_roi_extractor=None,
mask_head=None,
train_cfg=None, # assigner : MaxIoUAssigner ; sampler : RandomSampler
test_cfg=None, # skip
pretrained=None): # modelzoo://resnet50
assert bbox_roi_extractor is not None
assert bbox_head is not None
super(CascadeRCNN, self).__init__()
self.num_stages = num_stages
self.backbone = builder.build_backbone(backbone) # build backbone and Registry
#同上,创建模型,对各个组件(比如backbone、neck、bbox_head等字典数据,构建成module类)分别创建module类模型
if neck is not None:
self.neck = builder.build_neck(neck)
if rpn_head is not None:
self.rpn_head = builder.build_head(rpn_head)
if shared_head is not None:
self.shared_head = builder.build_shared_head(shared_head)
if bbox_head is not None:
self.bbox_roi_extractor = nn.ModuleList()
#ModuleList() 能够像列表一样索引 , [module1 , module2 , module3 ....]
#type='SingleRoIExtractor'
self.bbox_head = nn.ModuleList()
#SharedFCBBoxHead * 3 ; 三个字典构成list列表,字典的type一样,但是里面的其他字段不一样
if not isinstance(bbox_roi_extractor, list):
bbox_roi_extractor = [
bbox_roi_extractor for _ in range(num_stages)
# cascade rcnn, 1 stage + 3 stage , 3 include 3 times detection
]
if not isinstance(bbox_head, list): # bbox_head is list, so skip
bbox_head = [bbox_head for _ in range(num_stages)]
assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages