FasterRCNN源码解析——网络框架
文章目录
前言
主要是对哔哩哔哩up主霹雳吧啦Wz所讲解的视频Faster RCNN源码解析(pytorch)进行一个总结回顾,以加深印象。
一、FasterRCNN流程图
黄色虚线框代表只有在训练过程中才有的部分
二、框架
在faster_rcnn/network_files/faster_rcnn_framework.py脚本中
1. FasterRCNNBase
类
1.__init__
在初始化函数当中我们会传入backbone, rpn, roi_heads, transform
四个参数分别对应框架图的四个部分
def __init__(self, backbone, rpn, roi_heads, transform):
super(FasterRCNNBase, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
# used only on torchscript mode
self._has_warned = False
2.forward
传入的是 图片以及其标签,也就是读取解析PASCAL VOC2012数据集一文中的__getitem__
方法输出的image
和target
(type: (List[Tensor], list[Dict[Tensor]]))
这里输入的images的大小都是不同的,后面会进行预处理将这些图片放入同样大小的tensor中打包成一个batch
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Arguments:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
1.对数据进行数据预处理
两个操作1,标准化 2.限定输入图像最小边长和最大边长