Relation Networks for Object Detection

Relation Networks for Object Detection为 CVPR 2018 oral paper.

论文链接:https://arxiv.org/pdf/1711.11575.pdf

开源代码:https://github.com/msracver/Relation-Networks-for-Object-Detection

主要贡献:受Attention is all you need的影响,设计了relation模块去找物体之间的联系,在文中是在找rpn产生的proposals之间的appearance feature and geometry关系,从而提高目标检测算法的性能。就像出现太阳的图片一般不会出现月亮。有些物体很大几率一起出现,而有些物体基本不会出现在一起。最后设计了一个Relation for Duplicate Removal模块去替代rcnn阶段的nms操作。作者认为, NMS uses simple relations between bounding boxes and scores,设计的Duplicate Removal可以更好地去去掉那些应该去掉的boxes。

网络框架图如下图。

 

对于relation模块,图来自:https://blog.sunnyan.cn/2018/04/18/relation-networks-for-object-detection/

作者设置 Nr=16,dk=64,dg=64Nr=16,dk=64,dg=64。输入的 proposals 数量 N = 300

relation module的具体公式表现可以见文章中的第三节Object Relation Module或者看https://blog.sunnyan.cn/2018/04/18/relation-networks-for-object-detection/讲的很清楚,下面参考https://blog.csdn.net/u014380165/article/details/80779712给出伪代码 

#求rois之间的坐标关系,找到位置信息
rois=[300,4]
nongt_dim = 300
position_matrix = extract_position_matrix(rois, nongt_dim=nongt_dim) #position_matrix=[300,300,4]

#extract_position_embedding方法实现论文中公式5的EG操作
position_embedding = extract_position_embedding(position_matrix, feat_dim=64) #[300,300,64]
def extract_position_embedding(position_mat=position_matrix, feat_dim=16, wave_length=1000):
        feat_range = arange(0, feat_dim / 8) #[0,1,2,3,4,5,6,7]
        dim_mat = broadcast_power(lhs=full((1,), wave_length), rhs=(8. / feat_dim) * feat_range)#[1., 2.37137365, 5.62341309,
# 13.33521461, 31.62277603, 74.98941803, 177.82794189, 421.69650269]
        dim_mat = Reshape(dim_mat, shape=(1, 1, 1, -1)) #[1,1,1,8]
        position_mat = expand_dims(100.0 * position_mat, axis=3) #[300,300,4,1]
        div_mat = broadcast_div(lhs=position_mat, rhs=dim_mat) #[300,300,4,8]
        sin_mat = sin(data=div_mat)
        cos_mat = cos(data=div_mat)
        embedding = concat(sin_mat, cos_mat, dim=3) #[300,300,4,16]
        embedding = Reshape(embedding, shape=(0, 0, feat_dim)) #[300,300,64]
        return embedding

roi_pool=[300,256,7,7]
fc_new_1 = fc(roi_pool) #fc[256*7*7, 1024] fc_new_1[300,1024]
attention_1 = attention_module_multi_head(fc_new_1, position_embedding, nongt_dim=nongt_dim, fc_dim=16, feat_dim=1024, index=1, group=16, dim=(1024, 1024, 1024))

def attention_module_multi_head(self, roi_feat, position_embedding,
                                    nongt_dim, fc_dim, feat_dim,
                                    dim=(1024, 1024, 1024),
                                    group=16, index=1):

        dim_group = (dim[0] / group, dim[1] / group, dim[2] / group) #(64, 64, 64)
        nongt_roi_feat = slice_axis(data=roi_feat, axis=0, begin=0, end=nongt_dim) #[300,1024]
        #position_embedding[300,300,64] position_embedding_reshape[300*300,64]
        position_embedding_reshape = Reshape(position_embedding, shape=(-3, -2))

        position_feat_1_relu = relu(fc(position_embedding_reshape)) #fc(64,16)全连接层的参数就是公式5的WG, position_embedding_reshape[300*300,16)
        aff_weight = Reshape(position_feat_1_relu, shape=(-1, nongt_dim, fc_dim)) #[300,300,16]
        aff_weight = transpose(aff_weight, axes=(0, 2, 1)) #[300,16,300]

        q_data = fc(roi_feat) #fc(1024,1024) q_data[300,1024]  全连接层参数对应论文中公式4的WQ,roi_feat对应公式4的fA
        q_data_batch = Reshape(q_data, shape=(-1, group, dim_group[0])) #[300,16,64]
        q_data_batch = transpose(q_data_batch, axes=(1, 0, 2)) #[16,300,64]

        k_data = fc(nongt_roi_feat) #nongt_roi_feat其实就是roi_feat, fc(1024,1024) k_data[300,1024]  全连接层参数对应论文中公式4的WK,nongt_roi_feat对应公式4中的fA
        k_data_batch = Reshape(k_data, shape=(-1, group, dim_group[1])) #[300,16,64]
        k_data_batch = transpose(k_data_batch, axes=(1, 0, 2)) #[16,300,64]

        v_data = nongt_roi_feat #[300,1024]

        aff = batch_dot(lhs=q_data_batch, rhs=k_data_batch, transpose_a=False, transpose_b=True) #[16,300,300] batch_dot操作就是论文中公式4的dot
        aff_scale = (1.0 / math.sqrt(float(dim_group[1]))) * aff #公式4中的除法
        aff_scale = transpose(aff_scale, axes=(1, 0, 2)) #[300,16,300] aff_scale就是论文中公式4的结果:wA

        weighted_aff = log(maximum(left=aff_weight, right=1e-6)) + aff_scale #maximum得到的结果为公式5得到的,log(wG)+wA [300,16,300]
        aff_softmax = softmax(data=weighted_aff, axis=2)) #e^(log(wG)+wA)=wG*e^(wA) softmax实现论文中公式3的操作,axis设置为2表示在维度2上进行归一化
        aff_softmax_reshape = Reshape(aff_softmax, shape=(-3, -2)) #[300*16,300] 对应论文中公式3的w

        output_t = dot(lhs=aff_softmax_reshape, rhs=v_data) #v_data[300,1024] 对应论文中公式2的w和fA相乘的结果 [300*16,1024]
        output_t = Reshape(output_t, shape=(-1, fc_dim * feat_dim, 1, 1)) #[300,16*1024,1,1]


        #linear_out就是论文中公式2的fR。注意这里的卷积层有个num_group参数,
        # group数量设置为fc_dim,默认是16,对应论文中的Nr参数,因此论文中公式6的concat操
        # 作已经在这个卷积层中通过group操作实现了。
        linear_out = Convolution(data=output_t, kernel=(1, 1), num_filter=1024, num_group=16) #[300,1024,1,1] 卷积层的参数对应论文中公式2的WV
        output = Reshape(linear_out, shape=(0, 0)) #[300,1024]

        return output

对于最后取代NMS的duplicate removal network,其实就是用前面得到的score,然后relation module也给出的一个score,两个相乘得到最终score。用最终的score进行一个“二分类”来决定每个box是否应该保留。

 文章中进行了详细的实验在证明设计的relation module的有效性,而且不是由于加深了网络的深度宽度,参数量带来的效果,同时对比了设计的duplicate removal network与nms soft-nms的效果。

其他:

  • 结合位置信息与表观特征进行attention来找到不同proposal之间的关系,提高了目标检测的性能
  • 这种找不同物体之间关系的思想感觉可以用到视频目标检测中,去找到多帧之间物体之间的联系,但是怎么弄值得思考
  • 对于这个relation module具体学到了什么,其实解释不清楚,引用文中Conclusions的描述it is not clear what is learnt in the relation module, especially when multiple ones are stacked. our understanding of how relation module works is preliminary and left as future work.附上一张文中用来试图解释relation module学习内容的图,The left example suggests that several objects overlapping on the same ground truth (bicycle) contribute to the centering object. The right example suggests that the person contributes to the glove

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值