Faster RCNN代码详解(五):关于检测网络(Fast RCNN)的proposal

Faster RCNN代码详解(二):网络结构构建中介绍了Faster RCNN算法的网络结构,其中有一个用于生成ROI proposal target的自定义层,该自定义层的输出作为检测网络(Fast RCNN)的输入,这篇博客就来介绍这个自定义层的内容。

该自定义层的实现所在脚本~/mx-rcnn/rcnn/symbol/proposal_target.py,该层返回的group列表包含4个值,分别是rois,label,bbox_target,bbox_weight。roi用于ROI Pooling层,label用于检测网络的分类支路、bbox_target和bbox_weight用于检测网络的回归支路

通过系列四中对RPN网络中anchor的介绍,你应该明白这里的label、bbox_target、bbox_weight和RPN网络中的不同,RPN网络中的label、bbox_target和bbox_weight等变量的定义方式和这里不同,同时在RPN网络中那边变量是服务于anchor的。

"""
Proposal Target Operator selects foreground and background roi and assigns label, bbox_transform to them.
"""

import logging
import mxnet as mx
import numpy as np
from distutils.util import strtobool

from ..logger import logger
from rcnn.io.rcnn import sample_rois

class ProposalTargetOperator(mx.operator.CustomOp):
    def __init__(self, num_classes, batch_images, batch_rois, fg_fraction):
        super(ProposalTargetOperator, self).__init__()
        self._num_classes = num_classes
        self._batch_images = batch_images
        self._batch_rois = batch_rois
        self._fg_fraction = fg_fraction

        if logger.level == logging.DEBUG:
            self._count = 0
            self._fg_num = 0
            self._bg_num = 0

    def forward(self, is_train, req, in_data, out_data, aux):
        assert self._batch_rois % self._batch_images == 0, \
            'BATCHIMAGES {} must devide BATCH_ROIS {}'.format(self._batch_images, self._batch_rois)
        rois_per_image = int(self._batch_rois / self._batch_images)
        fg_rois_per_image = int(round(self._fg_fraction * rois_per_image))


# all_rois的维度是(2000,5),不过all_rois除了4列坐标外,剩下一列全是0,
# 并不表示roi的标签,仅仅是batch的index标识。gt_boxes的维度是(x,5),x是object的数量。
        all_rois = in_data[0].asnumpy()
        gt_boxes = in_data[1].asnumpy()

        # Include ground-truth boxes in the set of candidate rois
# 初始化的zeros替换掉gt_boxes中object的类别,然后和原来的all_rois做合并,
# 最后得到的all_rois的维度是(2000+x,5)。因为all_rois变量中并不需要ground truth的标签,
# 所以都用0值替代。从最后的assert语句也可以看出第一列0值的含义是和batch相关。
        zeros = np.zeros((gt_boxes.shape[0], 1), dtype=gt_boxes.dtype)
        all_rois = np.vstack((all_rois, np.hstack((zeros, gt_boxes[:, :-1]))))
        # Sanity check: single batch only
        assert</
  • 5
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值