RefineDet算法源码(二)网络结构

论文:Single-Shot Refinement Neural Network for Object Detection
论文链接:https://arxiv.org/abs/1711.06897
代码链接:https://github.com/sfzhang15/RefineDet

关于RefineDet算法内容可以先看看博客:RefineDet论文笔记

RefineDet算法是SSD算法的升级版本,所以大部分的代码也是基于SSD的开源代码来修改的。SSD开源代码参考链接:https://github.com/weiliu89/caffe/tree/ssdRefineDet主要包含anchor refinement module (ARM) 、object detection module (ODM)、transfer connection block (TCB)3个部分,ARM部分可以直接用SSD代码,只不过将分类支路的类别数由object数量+1修改成2,类似RPN网络,目的是得到更好的初始bbox。ODM部分也可以基于SSD代码做修改,主要是原本采用的default box用ARM生成的bbox代替,剩下的分类和回归支路与SSD一样。TCB部分则通过一些卷积层和反卷积层即可实现。

在博客:RefineDet算法源码 (一)训练脚本中介绍了训练RefineDet算法的代码,其中包含宏观上的网络结构构建,并未涉及细节内容。因此这篇博客介绍RefineDet算法的具体网络结构构造细节,代码所在路径:~RefineDet/python/caffe/model_libs.py脚本的CreateRefineDetHead函数。

'''
CreateRefineDetHead函数是本文关于网络结构构造的重点,这部分代码也是在原来SSD的CreateMultiBoxHead函数
基础上修改得到的,可以看作是将原来SSD的CreateMultiBoxHead函数内容实现了两遍,一遍用来实现ARM部分,
另一边用来实现ORM部分。from_layers和from_layers2是两个重点输入,
分别对应论文中Figure1的ARM和OBM两部分输出。因此这两遍实现除了输入不同外,另一个不同是ARM部分
是类似RPN网络的bbox回归和二分类,而ORM部分是类似SSD检测网络的bbox回归和object分类。
'''
def CreateRefineDetHead(net, data_layer="data", num_classes=[], from_layers=[], from_layers2=[], normalizations=[], use_batchnorm=True, lr_mult=1, min_sizes=[], max_sizes=[], prior_variance = [0.1],aspect_ratios=[], steps=[], img_height=0, img_width=0, share_location=True, flip=True, clip=True, offset=0.5, inter_layer_depth=[], kernel_size=1, pad=0, conf_postfix='', loc_postfix='', **bn_param):
    assert num_classes, "must provide num_classes"
    assert num_classes > 0, "num_classes must be positive number"
    if normalizations:
        assert len(from_layers) == len(normalizations), "from_layers and normalizations should have same length"
    assert len(from_layers) == len(min_sizes), "from_layers and min_sizes should have same length"
    if max_sizes:
        assert len(from_layers) == len(max_sizes), "from_layers and max_sizes should have same length"
    if aspect_ratios:
        assert len(from_layers) == len(aspect_ratios), "from_layers and aspect_ratios should have same length"
    if steps:
        assert len(from_layers) == len(steps), "from_layers and steps should have same length"
    net_layers = net.keys()
    assert data_layer in net_layers, "data_layer is not in net's layers"
    if inter_layer_depth:
        assert len(
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值