FasterRCNN代码解读

原创 2018年01月24日 22:34:39

之前的文章简要介绍了Faster-RCNN等物体检测的算法,本文将从代码角度详细分析介绍Faster-RCNN的实现。本文使用的代码参考了chenyuntc的实现,代码的位置看这里。需要注意的是,本文使用的框架是Pytorch。

图片名称

数据载入

数据载入部分的代码主要见./data/dataset.py中的类DatasetTestDataset

数据载入部分的逻辑如下:

  1. 从VOC数据集中获得img, bbox, label
  2. img, bbox进行放缩(放缩的目的是让图片处于合适的大小,这样预先指定锚框才有意义)
  3. img进行标准化正则处理
  4. 如果是训练阶段,将img翻转以增加训练数据

网络结构

FasterRCNN的网络结构如下图所示:

这里写图片描述

FasterRCNN结构的代码主要见./model.faster_rcnn.py,其结构包含三大部分:

  1. 预训练的CNN模型 decom_vgg16
  2. rpn网络RegionProposalNetwork
  3. roi及以上网络VGG16RoIHead

下面,将以放缩后大小为[1, 3, 600, 800]的图片为例针对每个部分分别介绍。图像类别共计21类(包含背景)。

预训练的CNN模型

该部分代码见./model/vgg16.py

输入:图片,大小[1, 3, 600, 800]
输出:特征图features,大小[1, 512, 37, 50]


其逻辑如下:

  1. 载入预先训练好的CNN模型VGG16。
  2. 将模型拆分为两部分extractor, classifier。其中,extractor的参数固定。
  3. 图片通过extractor可以得到特征图features。根据extractor中池化参数可知图像通过extractor缩小了16倍。

rpn网络

该部分代码见./model/rpn.py

输入:特征图features,大小[1, 512, 37, 50]
输出:

  • rpn_locs:rpn对位置的修正,大小[1, 16650, 4]
  • rpn_scores :rpn判断区域前景背景,大小[1, 16650, 2]
  • rois:rpn筛选出的roi的位置,大小[300, 4]
  • roi_indices:rpn筛选出的roi对应的图片索引,大小[300]
  • anchor:原图像的锚点,大小[16650, 4]

其中,16650是放缩后的图像所产生的所有锚点(37*50*9),每个锚点都对应了一个rp。通过 rpn_scores以及nms可以得到筛选后的大小为300的roi。


其逻辑如下:

  1. 对特征图features以基准长度为16、选择合适的ratiosscales取基准锚点anchor_base。(选择长度为16的原因是图片大小为600*800左右,基准长度16对应的原图区域是256*256,考虑放缩后的大小有128*128,512*512比较合适)
  2. 根据anchor_base在原图上获得anchors
  3. 对特征图features采用卷积得到rpn_locsrpn_scores
  4. 根据anchorsrpn_locs获得修正后的rp
  5. rp进一步修正获得roisroi_indices,修正包括超出边界的部分截断、移除太小的、nms。

roi及以上网络

该部分代码见./model/roi_module.py

输入:

  • features:特征图,大小[1, 512, 37, 50]
  • rois:rpn筛选出的roi的位置,大小[300, 4]
  • roi_indices:rpn筛选出的roi对应的图片索引,大小[300]

输出:

  • roi_cls_locsroi位置的修正,大小[300, 84]
  • roi_scoresroi各类的分数,大小[300, 21]

其逻辑如下:

  1. 通过RoIPooling2D将大小不同的roi变成大小一致,得到pooling后的特征,大小为[300, 512, 7, 7]
  2. 接入预训练的CNN模型引入的classifier
  3. 分别接入全连接得到roi_cls_locsroi_scores

训练

训练部分的代码主要见./trainer/trainer.py中的FasterRCNNTrainer中的train_step函数。

训练部分的核心是loss如何求取。

loss求取前网络的步骤如下:

  1. 预训练CNN特征提取:输入imgextractor获得features
  2. rpn网络得到roi:输入featuresrpn获得rpn_locs, rpn_scores, rois, roi_indices, anchor
  3. 抽样roi:输入roisbboxlabelProposalTargetCreator获得sample_roi, gt_roi_loc, gt_roi_label。该步骤的含义是得到正负例比例和位置合适的roi
  4. head网络得到roi的位置修正与分数:输入features,sample_roi,sample_roi_index得到roi_cls_loc, roi_score

各个loss求取的方式如下:

  1. rpn_loc_loss:已知rpn_loc,需要先根据anchorbbox得到真实的gt_rpn_locgt_rpn_label。该处loss的计算只考虑前景,所以根据rpn_loc,gt_rpn_loc,gt_rpn_label计算L1-LOSS即可。
  2. rpn_cls_loss:根据rpn_scoregt_rpn_label计算二分类的交叉熵即可。
  3. roi_loc_loss:已知roi_loc,在sample roi的过程中已获得gt_roi_loc, gt_roi_label。根据roi_loc,gt_roi_loc,gt_roi_label计算L1-LOSS即可。
  4. roi_cls_loss:根据roi_scoregt_roi_label计算多分类的交叉熵即可。

整体的loss为以上各loss相加求和。

测试

训练部分的代码主要见./model/faster_rcnn.py中的FasterRCNNTrainer中的predict函数。

其步骤如下:

  1. 图片预处理
  2. 预训练CNN特征提取:输入imgextractor获得features
  3. rpn网络得到roi:输入featuresrpn获得rpn_locs, rpn_scores, rois, roi_indices, anchor
  4. head网络得到roi的位置修正与分数:输入features,rois,roi_indices得到roi_cls_loc, roi_score
  5. 得到图片预测的bbox:输入roi_cls_locroi_scorerois,采用nms等方法得到预测的bbox

Faster-RCNN_TF代码解读10:proposal_layer_tf.py

# -------------------------------------------------------- # Faster R-CNN # Copyright (c) 2015 Micro...
  • l297969586
  • l297969586
  • 2017年09月18日 16:31
  • 809

Fast-RCNN代码解读(0)

Fast-RCNN代码解读(1)由于博主最近正在尝试修改rbg大神的fast-rcnn代码,为了推进自己的学习进度,在此做一个记录,同时也便于跟大家交流。...
  • applecore123456
  • applecore123456
  • 2016年10月12日 20:35
  • 535

r-cnn系列代码编译及解读(1)

本系列针对RBG在github上的fast r-cnn代码,做安装配置及解读工作 本文解决由于CAFFE版本的更新导致的fast r-cnn编译失败的问题...
  • zizi7
  • zizi7
  • 2017年05月02日 19:27
  • 455

faster rcnn 源码解读

faster rcnn 源码解读faster rcnn同fast rcnn相比,就是将ss(候选框提取)的算法融合到了网络总,这样可以在网络中共享卷积层,计算效率更高。这里讲解了faster rcnn...
  • bailufeiyan
  • bailufeiyan
  • 2016年02月26日 14:46
  • 13526

fasterRCNN详解

R-CNN学习笔记:http://blog.csdn.NET/xzzppp/article/details/51345742 Fast R-CNN学习笔记:http://blog.csdn.net/...
  • u013126125
  • u013126125
  • 2016年10月19日 14:51
  • 2770

Faster-RCNN代码+理论——1

昨天刚参加完一个ibm的医疗影像大赛——我负责的模型是做多目标识别并输出位置的模型。由于之前没有什么经验,采用了在RGB图像上表现不错的Faster-RCNN,但是比赛过程表明:效果不是很好。所以这里...
  • g11d111
  • g11d111
  • 2017年12月17日 12:23
  • 332

解读ssd中训练代码中知识点

作为一个渗透测试学习者来说,对系统的足够了解是基本的要求,下面就通过对os.environ中的key解读的角度来认识系统。 windows: · os.environ['HOMEPATH']:...
  • SMF0504
  • SMF0504
  • 2017年05月11日 22:00
  • 491

Faster RCNN代码理解(Python) ---训练过程

最近开始学习深度学习,看了下Faster RCNN的代码,在学习的过程中也查阅了很多其他人写的博客,得到了很大的帮助,所以也打算把自己一些粗浅的理解记录下来,一是记录下自己的菜鸟学习之路,方便自己过后...
  • u014696921
  • u014696921
  • 2017年03月04日 09:56
  • 4224

【E2LSH源码分析】E2LSH源码综述及主要数据结构

E2LSH的核心代码可以分为3部分: LocalitySensitiveHashing.cpp——主要包含基于LSH的RNN(R-near neighbor)数据结构。其主要功能是根据参数构建数据结构...
  • JasonDing1354
  • JasonDing1354
  • 2014年08月01日 21:27
  • 2379

fasterrcnn测试

fasterrcnn测试
  • forest_world
  • forest_world
  • 2017年03月13日 11:01
  • 219
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:FasterRCNN代码解读
举报原因:
原因补充:

(最多只允许输入30个字)