本文主要和大家分享一下我个人对于Faster R-CNN代码的解读,希望可以帮助到大家!整体框架train.py
faster_rcnn.py
rpn.py
generate_anchors.py
anchor_target.py
proposal.py
proposal_target.py
1.整体框架
在详细介绍代码细节之前,我们可以先理清Faster RCNN的整体框架和整个训练过程。整个过程涉及到三个文件:train.py,faster_rcnn.py和rpn.py。在这里,我们只需要理清主线,所以我简化了这3个文件里的代码,完整的代码可以参考我的GitHub:
1.1 train.py
train.py就是我们训练时运行的文件,主要作用就是调用FasterRCNN网络得到分类和检测结果,然后计算loss,再用梯度下降优化网络,大致可以总结为以下5个步骤:加载训练数据
定义模型FasterRCNN
将数据输入到模型中,并得到模型的输出
根据模型的输出,计算loss,loss就是faster_rcnn的分类loss和回归loss,以及rpn的分类loss和回归loss的均值和
进行反向传播
在这几步中,最核心的是模型,所以下一步就要去看一下FasterRCNN是如何实现的。
1.2 faster_rcnn.py
在faster_rcnn.py中主要定义了FasterRCNN这个类,在这个类中构建了Faster RCNN整个网络,也很清楚的给出了整个流程,具体包括以下步骤:首先使用backbone网络提取输入图片的特征
使用RPN网络来提取rois
如果是训练,得到proposal_target,即分类和回归的ground truth,后续计算faster rcnn的loss时需要用到
使用roi_pooling得到rois的feature map
使用classifier提取特征
使用faster_rcnn_cls得到分类结果
使用faster_rcnn_reg得到回归结果
如果是训练,计算分类loss
如果是训练,计算回归loss
在FasterRCNN中我们需要重点关注的是rpn网络提取proposal,以及proposal_target函数。其中rpn是流程的主线,所以我们接下来先讲解rpn网络,对于proposal_target,我们之后会详细讲解。
1.3 rpn.py
RPN网络的结构是在rpn.py中实现的,主要作用就是计算anchor进行分类和回归结果,然后根据分类和回归结果调用proposal函数得到proposals(rois),大致可以总结为以下几步:对于输入的feature map先用rpn_conv进行卷积
然后使用rpn_cls卷积层得到分类结果
同时使用rpn_reg卷积层得到回归结果
然后之后再调用proposal函数得到proposals(rois)
如果是训练过程,那么使用调用anchor_target产生rpn网络中分类和回归的ground truth值,之后在计算rpn的loss时会用到
如果是训练过程,那么计算分类loss
如果是训练过程,那么计算回归loss
RPN完成之后,返回rois和loss,这就相当于完成上述FasterRCNN中的第2步了,然后在FasterRCNN中接着完成剩余的步骤,然后FasterRCNN会返回结果,这就完成了上述Train中的第3步了,然后在Train中接着完成其他的步骤便结束了一个完整的流程。
流程中有几个重要的函数需要详细讲解,包括anchor_target,proposal和proposal_ta