以下仅为个人理解,若有不正之处还请指出,欢迎交流!
一、整体过程
mmdet/models/detectors/two_stage.py中的部分代码
# Excerpt from mmdet/models/detectors/two_stage.py (inside forward_train):
# when an RPN head is configured, compute its losses on the backbone
# features and generate region proposals; otherwise fall back to
# externally supplied (precomputed) proposals.
if self.with_rpn:
    rpn_outs = self.rpn_head(x)
    # Pack ground truth and the training config together with the raw
    # head outputs so they can be splatted into rpn_head.loss(...).
    rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                  self.train_cfg.rpn)
    rpn_losses = self.rpn_head.loss(
        *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
    losses.update(rpn_losses)
    # An explicit 'rpn_proposal' entry in train_cfg overrides the
    # test-time RPN config when generating proposals during training.
    proposal_cfg = self.train_cfg.get('rpn_proposal',
                                      self.test_cfg.rpn)
    proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
    proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
    proposal_list = proposals
二、详细解读
1.RPN(rpn_head)网络结构
网络结构示意图如下图所示(原图在转载中缺失)。代码实现如下:
class RPNHead(AnchorHead):
    """Region Proposal Network head.

    A class-agnostic specialization of ``AnchorHead``: a shared 3x3 conv
    followed by two sibling 1x1 convs that predict, for every anchor at
    every spatial position, an objectness score and a 4-d box delta.
    """

    def __init__(self, in_channels, **kwargs):
        # num_classes is fixed to 2 (object vs. background) for RPN.
        super(RPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        # Shared 3x3 conv applied to each feature level.
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        # 1x1 classification head: one score set per anchor.
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        # 1x1 regression head: 4 box deltas per anchor.
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        # Gaussian init with std=0.01, following the Faster R-CNN paper.
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        """Forward one feature level; returns (cls_score, bbox_pred) maps."""
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred
2.RPN_Loss
rpn_losses = self. rpn_head. loss( * rpn_loss_inputs, gt_bboxes_ignore= gt_bboxes_ignore)
def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """RPNHead.loss: delegate to AnchorHead.loss without gt labels.

    ``gt_labels`` is passed as ``None`` because RPN is class-agnostic —
    every ground-truth box counts as foreground. The generic loss keys
    are renamed with an ``rpn_`` prefix so they do not collide with the
    downstream R-CNN head's losses in the overall loss dict.
    """
    losses = super(RPNHead, self).loss(
        cls_scores,
        bbox_preds,
        gt_bboxes,
        None,  # gt_labels: RPN only separates object vs. background
        img_metas,
        cfg,
        gt_bboxes_ignore=gt_bboxes_ignore)
    return dict(
        loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    """AnchorHead.loss: build anchor targets, then reduce per-level losses.

    Args:
        cls_scores: per-level classification maps from the head.
        bbox_preds: per-level box-delta maps from the head.
        gt_bboxes: per-image ground-truth boxes.
        gt_labels: per-image gt labels (None for the class-agnostic RPN).
        img_metas: per-image meta info (shape, scale factor, ...).
        cfg: assigner/sampler configuration.
        gt_bboxes_ignore: optional boxes excluded from assignment.

    Returns:
        dict with per-level lists ``loss_cls`` and ``loss_bbox``, or
        ``None`` when target generation fails.
    """
    # One (H, W) per feature level; must match the anchor generators.
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)
    device = cls_scores[0].device
    # Per-image, per-level anchors plus flags for anchors inside the image.
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas, device=device)
    label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
    # Assign anchors to gt boxes by IoU, sample pos/neg anchors, and encode
    # regression targets as deltas normalized by target_means/target_stds.
    cls_reg_targets = anchor_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=label_channels,
        sampling=self.sampling)
    if cls_reg_targets is None:
        return None
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    # With sampling, normalize by all sampled anchors; otherwise (e.g. focal
    # loss) by positives only.
    num_total_samples = (
        num_total_pos + num_total_neg if self.sampling else num_total_pos)
    # Compute the loss of each feature level independently.
    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        cls_scores,
        bbox_preds,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        bbox_weights_list,
        num_total_samples=num_total_samples,
        cfg=cfg)
    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, bbox_weights, num_total_samples, cfg):
    """Compute classification and regression losses for one feature level.

    The (N, C, H, W) prediction maps are permuted to channel-last and
    flattened so each row corresponds to one anchor, matching the flat
    target tensors produced by anchor_target.
    """
    # Flatten targets: one entry per anchor.
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    # (N, C, H, W) -> (N, H, W, C) -> (num_anchors_total, cls_out_channels)
    cls_score = cls_score.permute(0, 2, 3,
                                  1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        cls_score, labels, label_weights, avg_factor=num_total_samples)
    bbox_targets = bbox_targets.reshape(-1, 4)
    bbox_weights = bbox_weights.reshape(-1, 4)
    # (N, 4*A, H, W) -> (num_anchors_total, 4)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    loss_bbox = self.loss_bbox(
        bbox_pred,
        bbox_targets,
        bbox_weights,
        avg_factor=num_total_samples)
    return loss_cls, loss_bbox
更多细节不再展开,概括地讲,RPN网络会根据设定的尺度和纵横比生成大量anchor box(anchor box介绍可参考理解anchor box究竟是如何生成的 ),通过anchor box与gt box之间的IoU值来判定anchor box的真实标签,并进行采样,例如一个mini-batch选择256个anchor,设定正负样本比例1:1,再计算标记为正样本的anchor box相对于gt box的真实坐标偏移量。然后通过训练使loss收敛,以使RPN网络的预测结果更接近真实结果。
3.生成区域建议候选框(proposals)
proposal_list = self. rpn_head. get_bboxes( * proposal_inputs)
def get_bboxes(self,
               cls_scores,
               bbox_preds,
               img_metas,
               cfg,
               rescale=False):
    """Turn per-level head outputs into per-image proposal lists.

    Regenerates all multi-level anchors once (they are shared across the
    batch), then decodes proposals image by image via get_bboxes_single.

    Returns:
        list with one proposal tensor per image in the batch.
    """
    assert len(cls_scores) == len(bbox_preds)
    num_levels = len(cls_scores)
    device = cls_scores[0].device
    # Anchors depend only on feature-map size and stride, not on the image,
    # so one set per level serves the whole batch.
    mlvl_anchors = [
        self.anchor_generators[i].grid_anchors(
            cls_scores[i].size()[-2:],
            self.anchor_strides[i],
            device=device) for i in range(num_levels)
    ]
    result_list = []
    for img_id in range(len(img_metas)):
        # detach(): proposal generation must not backprop into the RPN.
        cls_score_list = [
            cls_scores[i][img_id].detach() for i in range(num_levels)
        ]
        bbox_pred_list = [
            bbox_preds[i][img_id].detach() for i in range(num_levels)
        ]
        img_shape = img_metas[img_id]['img_shape']
        scale_factor = img_metas[img_id]['scale_factor']
        proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
                                           mlvl_anchors, img_shape,
                                           scale_factor, cfg, rescale)
        result_list.append(proposals)
    return result_list
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      mlvl_anchors,
                      img_shape,
                      scale_factor,
                      cfg,
                      rescale=False):
    """Decode proposals for a single image.

    Per level: score anchors, keep the top ``cfg.nms_pre``, apply the
    predicted deltas, drop tiny boxes, and run NMS. The per-level
    survivors are then merged and reduced to at most ``cfg.max_num``.

    Returns:
        (num_proposals, 5) tensor: [x1, y1, x2, y2, score] rows.
    """
    mlvl_proposals = []
    for idx in range(len(cls_scores)):
        rpn_cls_score = cls_scores[idx]
        rpn_bbox_pred = bbox_preds[idx]
        assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
        # Channel-last so reshape yields one row per anchor.
        rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.reshape(-1)
            scores = rpn_cls_score.sigmoid()
        else:
            # Two-class softmax: column 1 is the foreground probability.
            rpn_cls_score = rpn_cls_score.reshape(-1, 2)
            scores = rpn_cls_score.softmax(dim=1)[:, 1]
        rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        anchors = mlvl_anchors[idx]
        # Pre-NMS top-k filtering to bound the NMS workload.
        if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
            _, topk_inds = scores.topk(cfg.nms_pre)
            rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
            anchors = anchors[topk_inds, :]
            scores = scores[topk_inds]
        # Apply predicted deltas to anchors, clipped to the image.
        proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                               self.target_stds, img_shape)
        if cfg.min_bbox_size > 0:
            # +1 follows the legacy inclusive-pixel box convention.
            w = proposals[:, 2] - proposals[:, 0] + 1
            h = proposals[:, 3] - proposals[:, 1] + 1
            valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                       (h >= cfg.min_bbox_size)).squeeze()
            proposals = proposals[valid_inds, :]
            scores = scores[valid_inds]
        # Append scores as a 5th column for NMS and later ranking.
        proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
        proposals, _ = nms(proposals, cfg.nms_thr)
        proposals = proposals[:cfg.nms_post, :]
        mlvl_proposals.append(proposals)
    proposals = torch.cat(mlvl_proposals, 0)
    if cfg.nms_across_levels:
        # Optional second NMS pass over the merged levels.
        proposals, _ = nms(proposals, cfg.nms_thr)
        proposals = proposals[:cfg.max_num, :]
    else:
        # Otherwise just keep the globally highest-scoring boxes.
        scores = proposals[:, 4]
        num = min(cfg.max_num, proposals.shape[0])
        _, topk_inds = scores.topk(num)
        proposals = proposals[topk_inds, :]
    return proposals
该部分即为获取区域建议候选框的过程,首先重新生成所有anchor box,并根据预测置信度进行初步筛选,将anchor box叠加rpn网络预测的坐标偏移量得到更为准确的候选框,再通过NMS(非最大值抑制)得到最终的区域建议候选框。 由于源代码中RPN网络输入多个feaure map,所以更详细地分析会略微复杂一些,但原理都是相同的。