py-faster-rcnn源码解读系列(四)——anchor_target_layer.py

本文介绍了在solver中出现的用python定义的layer,顾名思义,该layer主要功能是产生anchor,并对anchor进行评分等操作,详细见代码注释。

 class AnchorTargetLayer(caffe.Layer):
"""
Assign anchors to ground-truth targets. Produces anchor classification
labels and bounding-box regression targets.
"""

setup函数

首先读取了,在.prototxt中定义的相关参数,事实上只有feat\_stride,一般被定义为16.
然后设置了相关参数比如\_anchors,由一个工具py中的方法generate\_anchors产生,通常为如下九个,有兴趣的读者不妨在纸上画一画,便可知道其中奥秘,在这里卖个关子:)
 anchors =  
 (xmin  ymin xmax ymax)
 -83   -39   100    56
 -175   -87   192   104
 -359  -183   376   200
 -55   -55    72    72
 -119  -119   136   136
 -247  -247   264   264
 -35   -79    52    96
 -79  -167    96   184
 -167  -343   184   360
以及一些其他需要用到的的属性。
def setup(self, bottom, top):
    layer_params = yaml.load(self.param_str_)
    anchor_scales = layer_params.get('scales', (8, 16, 32))
    self._anchors = generate_anchors(scales=np.array(anchor_scales))
    self._num_anchors = self._anchors.shape[0]
    self._feat_stride = layer_params['feat_stride']
    #fg指的是前景 fore ground   bg指的是背景 back ground
    self._counts = cfg.EPS
    self._sums = np.zeros((1, 4))
    self._squared_sums = np.zeros((1, 4))
    self._fg_sum = 0
    self._bg_sum = 0
    self._count = 0

    # allow boxes to sit over the edge by a small amount
    self._allowed_border = layer_params.get('allowed_border', 0)

    height, width = bottom[0].data.shape[-2:]
    #A 一般为 9
    A = self._num_anchors
    # 在这里将top的维度结构reshape
    # labels
    top[0].reshape(1, 1, A * height, width)
    # bbox_targets
    top[1].reshape(1, A * 4, height, width)
    # bbox_inside_weights
    top[2].reshape(1, A * 4, height, width)
    # bbox_outside_weights
    top[3].reshape(1, A * 4, height, width)

forward

前向传播:
在函数开头的注释已经阐述的很清楚了,对于每一个(H,W)位置点,都产生九个不同形状的anchor,在网络结构定义中H=61,W=36你会发现这里的H x feat_stride以及W x feat_stride正好约等于rescale以后的每张图的大小,好像是(900 x 533)?
然后仅仅保留范围在原图中的anchor,大概裁掉了2/3这样,并分别计算这些anchor与每个ground truth的重合度。

def forward(self, bottom, top):
    # Algorithm:
    #
    # for each (H, W) location i
    #   generate 9 anchor boxes centered on cell i
    #   apply predicted bbox deltas at cell i to each of the 9 anchors
    # filter out-of-image anchors
    # measure GT overlap    
    # map of shape (..., H, W)
    height, width = bottom[0].data.shape[-2:]
    # GT boxes (x1, y1, x2, y2, label)
    gt_boxes = bottom[1].data
    # im_info
    im_info = bottom[2].data[0, :]

    #在61 x 36每一个位置点上生成九个anchor,你可以想象成在一张图中均匀地取了61 x 36个点,然后 shift_x和shift_y分别是这些点在图中的偏移位置,让这些偏移值加上每个anchor的四个坐标点。然后就获得了一个all_anchors,一个(K*A,4)大的二维数组。
    # 1. Generate proposals from bbox deltas and shifted anchors
    shift_x = np.arange(0, width) * self._feat_stride
    shift_y = np.arange(0, height) * self._feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),
    shift_x.ravel(), shift_y.ravel())).transpose()
    # add A anchors (1, A, 4) to
    # cell K shifts (K, 1, 4) to get
    # shift anchors (K, A, 4)
    # reshape to (K*A, 4) shifted anchors
    A = self._num_anchors
    K = shifts.shape[0]
    all_anchors = (self._anchors.reshape((1, A, 4)) +
    shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
    all_anchors = all_anchors.reshape((K * A, 4))
    total_anchors = int(K * A)
    #裁掉大小超出图片的anchor,inds_inside是在图像内部的anchor的索引数组
    # only keep anchors inside the image
    inds_inside = np.where(
    (all_anchors[:, 0] >= -self._allowed_border) &
    (all_anchors[:, 1] >= -self._allowed_border) &
    (all_anchors[:, 2] < im_info[1] + self._allowed_border) &  # width
    (all_anchors[:, 3] < im_info[0] + self._allowed_border)    # height
    )[0]
    # keep only inside anchors
    anchors = all_anchors[inds_inside, :]

    # label: 1 is positive, 0 is negative, -1 is dont care
    labels = np.empty((len(inds_inside), ), dtype=np.float32)
    labels.fill(-1)
    #这里overlaps是计算所有anchor与ground-truth的重合度,它是一个len(anchors) x len(gt_boxes)的二维数组,每个元素是各个anchor和gt_boxes的overlap值,这个overlap值的计算是这样的:
    overlap = (重合部分面积) / (anchor面积 + gt_boxes面积 - 重合部分面积)
    · argmax_overlaps是每个anchor对应最大overlap的gt_boxes的下标
    · max_overlaps是每个anchor对应最大的overlap值
    相对应的
    · gt_argmax_overlaps是每个gt_boxes对应最大overlap的anchor的下标
    · gt_max_overlaps是每个gt_boxes对应最大的overlap值
    # overlaps between the anchors and the gt boxes
    # overlaps (ex, gt)
    overlaps = bbox_overlaps(
    np.ascontiguousarray(anchors, dtype=np.float),
    np.ascontiguousarray(gt_boxes, dtype=np.float))
    argmax_overlaps = overlaps.argmax(axis=1)
    max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps]
    gt_argmax_overlaps = overlaps.argmax(axis=0)
    gt_max_overlaps = overlaps[gt_argmax_overlaps,
    np.arange(overlaps.shape[1])]
    #加上这一步是因为有很多overlap并列第一
    gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0]
    #接下来是打标签的工作
    if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:
    # assign bg labels first so that positive labels can clobber them
    labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

    # fg label: for each gt, anchor with highest overlap
    labels[gt_argmax_overlaps] = 1

    # fg label: above threshold IOU
    labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1

    if cfg.TRAIN.RPN_CLOBBER_POSITIVES:
    # assign bg labels last so that negative labels can clobber positives
    labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0

    #接下来两步工作是为了让正样本与负样本严格保持1:1
    # subsample positive labels if we have too many
    num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)
    fg_inds = np.where(labels == 1)[0]
    if len(fg_inds) > num_fg:
    disable_inds = npr.choice(
    fg_inds, size=(len(fg_inds) - num_fg), replace=False)
    labels[disable_inds] = -1

    # subsample negative labels if we have too many
    num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1)
    bg_inds = np.where(labels == 0)[0]
    if len(bg_inds) > num_bg:
    disable_inds = npr.choice(
    bg_inds, size=(len(bg_inds) - num_bg), replace=False)
    labels[disable_inds] = -1
    #print "was %s inds, disabling %s, now %s inds" % (
    #len(bg_inds), len(disable_inds), np.sum(labels == 0))

    #这里将计算每一个anchor与重合度最高的ground_truth的偏移值,详细的计算方法在论文中提到,在fast-rcnn/bbox_transform.py中的bbox_transform函数也非常容易看懂
    bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
    bbox_targets = _compute_targets(anchors, gt_boxes[argmax_overlaps, :])

    #这里是inside_weight和out_weight的计算。- -#不过好像全程都是1
    bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
    bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS)

    bbox_outside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
    if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0:
    # uniform weighting of examples (given non-uniform sampling)
    num_examples = np.sum(labels >= 0)
    positive_weights = np.ones((1, 4)) * 1.0 / num_examples
    negative_weights = np.ones((1, 4)) * 1.0 / num_examples
    else:
    assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) &
    (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1))
    positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT /
    np.sum(labels == 1))
    negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) /
    np.sum(labels == 0))
    bbox_outside_weights[labels == 1, :] = positive_weights
    bbox_outside_weights[labels == 0, :] = negative_weights


    #还记得文初将all_anchors裁减掉了2/3左右,仅仅保留在图像内的anchor吗,这里就是将其复原作为下一层的输入了,并reshape成相应的格式
    # map up to original set of anchors
    labels = _unmap(labels, total_anchors, inds_inside, fill=-1)
    bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0)
    bbox_inside_weights = _unmap(bbox_inside_weights, total_anchors, inds_inside, fill=0)
    bbox_outside_weights = _unmap(bbox_outside_weights, total_anchors, inds_inside, fill=0)


    # labels
    labels = labels.reshape((1, height, width, A)).transpose(0, 3, 1, 2)
    labels = labels.reshape((1, 1, A * height, width))
    top[0].reshape(*labels.shape)
    top[0].data[...] = labels

    # bbox_targets
    bbox_targets = bbox_targets \
    .reshape((1, height, width, A * 4)).transpose(0, 3, 1, 2)
    top[1].reshape(*bbox_targets.shape)
    top[1].data[...] = bbox_targets

    # bbox_inside_weights
    bbox_inside_weights = bbox_inside_weights \
    .reshape((1, height, width, A * 4)).transpose(0, 3, 1, 2)
    assert bbox_inside_weights.shape[2] == height
    assert bbox_inside_weights.shape[3] == width
    top[2].reshape(*bbox_inside_weights.shape)
    top[2].data[...] = bbox_inside_weights

    # bbox_outside_weights
    bbox_outside_weights = bbox_outside_weights \
    .reshape((1, height, width, A * 4)).transpose(0, 3, 1, 2)
    assert bbox_outside_weights.shape[2] == height
    assert bbox_outside_weights.shape[3] == width
    top[3].reshape(*bbox_outside_weights.shape)
    top[3].data[...] = bbox_outside_weights

    def backward(self, top, propagate_down, bottom):
    """This layer does not propagate gradients."""
    pass

    def reshape(self, bottom, top):
    """Reshaping happens during the call to forward."""
    pass

_unmap

上个函数将all_anchors裁减掉了2/3左右,仅仅保留在图像内的anchor,这里就是将其复原作为下一层的输入了,并reshape成相应的格式


def _unmap(data, count, inds, fill=0):
    """ Unmap a subset of item (data) back to the original set of items (of
    size count) """
    if len(data.shape) == 1:
    ret = np.empty((count, ), dtype=np.float32)
    ret.fill(fill)
    ret[inds] = data
    else:
    ret = np.empty((count, ) + data.shape[1:], dtype=np.float32)
    ret.fill(fill)
    ret[inds, :] = data
    return ret

_compute_targets

计算与每个anchor最大重合度的ground-truth的(x,y,width,height)的偏移值


def _compute_targets(ex_rois, gt_rois):
"""Compute bounding-box regression targets for an image."""

assert ex_rois.shape[0] == gt_rois.shape[0]
assert ex_rois.shape[1] == 4
assert gt_rois.shape[1] == 5

return bbox_transform(ex_rois, gt_rois[:, :4]).astype(np.float32, copy=False)
  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值