从源码解读Faster-RCNN--(4)位置预测分类网络和损失函数

上面那一篇文章,我们介绍了RPN网络,这一篇我们重点介绍对产生的rois处理的位置预测以及分类网络,和损失函数。
代码参考:simple-faster-rcnn-pytorch
博客参考:Faster_RCNN 3.模型准备(下)

位置预测以及分类网络

得到训练的候选框

接着上一篇文章,先看代码trainer.py:

#trainer.py
sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(
            roi,
            at.tonumpy(bbox),
            at.tonumpy(label),
            self.loc_normalize_mean,
            self.loc_normalize_std)#117
sample_roi_index = t.zeros(len(sample_roi))
roi_cls_loc, roi_score = self.faster_rcnn.head(
    features,
    sample_roi,
    sample_roi_index)

先整体看就是通过proposal_target_creator去从RPN产生的2000个候选框中选择一些sample_roi,送入faster_rcnn.head()中进行位置和类别的预测。其中对候选框的筛选上,我们用到了一些ground truth来进行辅助筛选,因为这部分我们重点训练分类网络faster_rcnn.head(),所以需要一些高置信度的候选框。
这里看proposal_target_creator()函数,首先找到其实现:

#model/utils/creator_tool.py
class ProposalTargetCreator(object):#8
	...
    def __call__(self, roi, bbox, label,...):#43
    	...
    	iou = bbox_iou(roi, bbox)#97
    	...
    	pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]#105
    	...
    	neg_index = np.where((max_iou < self.neg_iou_thresh_hi)....#113
 		keep_index = np.append(pos_index, neg_index)
        gt_roi_label = gt_roi_label[keep_index]
        gt_roi_label[pos_roi_per_this_image:] = 0  # negative labels --> 0
        sample_roi = roi[keep_index]   	

可以看到,函数中利用预测的候选框和真实的bbox的进行IOU,找到候选框roi对应的ground truth bbox位置和标签,然后选取一些正例和反例进行分类网络的训练。当然这里面也进行了一些筛选,将原来的2000个候选框中采样指定的个数,例如self.n_sample=128:

#model/utils/creator_tool.py
def __call__(self, roi, bbox, label,...):#43
	pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]#105
	...
	neg_index = np.where((max_iou < self.neg_iou_thresh_hi) &
                             (max_iou >= self.neg_iou_thresh_lo))[0]#113
    neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image
    ...
送入网络

在trainer.py向下看

#trainer.py
roi_cls_loc, roi_score = self.faster_rcnn.head(
    features,
    sample_roi,
    sample_roi_index)#120

这就是位置预测与分类的网络,我们看网络中的内容:

#model/faster_rcnn_vgg16.py
class VGG16RoIHead(nn.Module):#86
	...
	def forward(self, x, rois, roi_indices):#117
        # in case roi_indices is  ndarray
		...
        indices_and_rois =  xy_indices_and_rois.contiguous()

        pool = self.roi(x, indices_and_rois)
        pool = pool.view(pool.size(0), -1)
        fc7 = self.classifier(pool)
        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)
        return roi_cls_locs, roi_scores		

这里面按照roi_indices以及rois进行预测,其中具体前向和后向的实现,作者在model/utils/roi_cupy.py实现,感兴趣的可以去看这一部分。

损失函数

RPN网络的损失函数

首先看代码

#trainer.py
        gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
            at.tonumpy(bbox),
            anchor,
            img_size)#126
        gt_rpn_label = at.totensor(gt_rpn_label).long()
        gt_rpn_loc = at.totensor(gt_rpn_loc)
        rpn_loc_loss = _fast_rcnn_loc_loss(
            rpn_loc,
            gt_rpn_loc,
            gt_rpn_label.data,
            self.rpn_sigma)

这里面是将所有的anchor传入,然后找到所有的anchor对应的loc和label对应的ground truth,将之前得到anchors的rpn_loc一起通过_fast_rcnn_loc_loss计算损失。
这里看下anchor_target_creator的实现:

#model/utils/creator_tool.py
class AnchorTargetCreator(object):#136
	...
	def __call__(self, bbox, anchor, img_size):
		...
		inside_index = _get_inside_index(anchor, img_H, img_W)#203
		...
		argmax_ious, label = self._create_label(...)#205
		loc = bbox2loc(anchor, bbox[argmax_ious])
		label = _unmap(label, n_anchor, inside_index, fill=-1)
		loc = _unmap(loc, n_anchor, inside_index, fill=0)
		return loc, label

通过_get_inside_index找到在图像中的anchor,这里的label是二分类。那么对于前景和背景的损失呢?看下面

#trainer.py
rpn_cls_loss = F.cross_entropy(rpn_score, gt_rpn_label.cuda(), ignore_index=-1)#139

这里用到了我们在rpn网络中用到的rpn_score。

ROI的损失函数

因为我们在self.proposal_target_creator()已经产生了采样的roi以及其对应的位置和分类的ground truth,这里我们直接使用就好:

#trainer.py
        roi_loc_loss = _fast_rcnn_loc_loss(
            roi_loc.contiguous(),
            gt_roi_loc,
            gt_roi_label.data,
            self.roi_sigma)#152

        roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label.cuda())#158

这样,我们就得到了所有的损失,最后对他们就好。

#trainer.py
		losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]#162
        losses = losses + [sum(losses)]

整体结构

我们自此已经把整个过程已经过了一篇,细节部分可能没有细抠,这个感兴趣的可以仔细钻研一下,也不难。
faster rcnn
所以,整体上faster rcnn新加上的网络也不复杂,就几个层。
net
上面的两个框就是faster rcnn加上的几层网络。
整体上,faster rcnn就是利用anchor产生很多个候选框,然后基于网络矫正,这就是再筛选掉一部分,这就是RPN的作用,然后对剩下的rois利用gound truth再筛选掉一部分用作分类和定位网络。
如有错误,还请您指出,谢谢!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值