Faster R-CNN Code Walkthrough (9): rcnn_head

References for this Faster R-CNN code walkthrough:

https://github.com/adityaarun1/pytorch_fast-er_rcnn

https://github.com/jwyang/faster-rcnn.pytorch

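The R-CNN loss is computed in much the same way as the RPN loss. The last assembly step of the R-CNN stage is rcnn_head, which takes the proposals generated by the RPN stage and aligns them with the targets.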

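import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import RoIAlign, RoIPool

# cfg: the project's configuration dict, defined elsewhere in the repo
class RcnnHead(nn.Module):
    def __init__(self, nclasses, in_channels=4096):
        super(RcnnHead, self).__init__()
        self.n_classes = nclasses
        # spatial_scale 1/16 maps image coordinates onto the stride-16 feature map
        self.rcnn_roi_align = RoIAlign((cfg['pooling_size'], cfg['pooling_size']), 1.0 / 16.0, 0)
        self.rcnn_roi_pool = RoIPool((cfg['pooling_size'], cfg['pooling_size']), 1.0 / 16.0)
        # one score per class, and one 4-value box offset per class
        self.cls_score_net = nn.Linear(in_channels, self.n_classes)
        self.bbox_pred_net = nn.Linear(in_channels, 4 * self.n_classes)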

    def forward(self, base_feat, rois, vgg_classifier, mode):
        # each RoI row is (batch_index, x1, y1, x2, y2) in image coordinates
        if cfg['pooling_mode'] == 'crop':
            # The crop path depended on affine_grid_gen, self.grid_size and an
            # RoI-crop layer, none of which are defined in this class, and it
            # never assigned pooled_feat.
            raise NotImplementedError("pooling_mode 'crop' is not supported here")
        elif cfg['pooling_mode'] == 'align':
            pooled_feat = self.rcnn_roi_align(base_feat, rois.view(-1, 5))
        elif cfg['pooling_mode'] == 'pool':
            pooled_feat = self.rcnn_roi_pool(base_feat, rois.view(-1, 5))
        else:
            raise ValueError('unknown pooling_mode: {}'.format(cfg['pooling_mode']))

        # Flatten each pooled RoI to a vector. The -1 infers the number of RoIs
        # per image (assuming every image contributes the same number), so the
        # reshape no longer breaks when the batch holds more than one image.
        flat_feat = pooled_feat.flatten(1)
        pool5_feat = flat_feat.view(base_feat.size(0), -1, flat_feat.size(1))
        # the VGG classifier layers, passed in by the caller, produce the fc7 features
        fc7 = vgg_classifier(pool5_feat)
        # classification scores, predicted class index, and class probabilities
        rcnn_cls_score = self.cls_score_net(fc7)
        rcnn_cls_pred = torch.max(rcnn_cls_score, 2)[1]
        rcnn_cls_prob = F.softmax(rcnn_cls_score, 2)
        # per-class bbox offsets
        rcnn_bbox_pred = self.bbox_pred_net(fc7)
        return rcnn_cls_pred, rcnn_cls_prob, rcnn_cls_score, rcnn_bbox_pred

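Several pooling/alignment schemes are available here, but only align is actually used. It ships with torchvision, and its main job is to align each proposal to a fixed-size feature grid.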

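Once alignment is done, the remaining operations are easy to follow. Passing in vgg_classifier is not the only way to write this; the author simply felt that doing so puts the whole VGG network to use.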
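The steps after pooling can be traced shape by shape with a small sketch. The vgg_classifier below is a hypothetical stand-in built by hand (two fully connected layers ending in 4096 features, mirroring in_channels=4096), and the channel count, class count and number of RoIs are assumptions picked to keep the example light.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_classes = 21       # assumed: 20 object classes plus background
feat_channels = 64   # assumed small for this sketch; VGG16 conv5 has 512
pooling_size = 7     # assumed pooling size
hidden = 4096        # matches in_channels=4096 in RcnnHead

# Hypothetical stand-in for vgg_classifier: two FC layers, no final class layer.
vgg_classifier = nn.Sequential(
    nn.Linear(feat_channels * pooling_size * pooling_size, hidden),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(hidden, hidden),
    nn.ReLU(inplace=True),
    nn.Dropout(),
).eval()

cls_score_net = nn.Linear(hidden, n_classes)
bbox_pred_net = nn.Linear(hidden, 4 * n_classes)

# Pretend RoI pooling returned 5 RoIs for a batch of one image.
pooled_feat = torch.randn(5, feat_channels, pooling_size, pooling_size)
pool5_feat = pooled_feat.view(1, pooled_feat.size(0), -1)    # (1, 5, 3136)

with torch.no_grad():
    fc7 = vgg_classifier(pool5_feat)                         # (1, 5, 4096)
    rcnn_cls_score = cls_score_net(fc7)                      # (1, 5, 21)
    rcnn_cls_prob = F.softmax(rcnn_cls_score, dim=2)         # rows sum to 1
    rcnn_cls_pred = rcnn_cls_score.argmax(dim=2)             # (1, 5)
    rcnn_bbox_pred = bbox_pred_net(fc7)                      # (1, 5, 84)

print(fc7.shape, rcnn_cls_prob.shape, rcnn_cls_pred.shape, rcnn_bbox_pred.shape)
```

Because nn.Linear acts on the last dimension only, the (batch, num_rois, features) layout passes through the classifier and both output layers unchanged.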
