论文:Single-Shot Refinement Neural Network for Object Detection
论文链接:https://arxiv.org/abs/1711.06897
代码链接:https://github.com/sfzhang15/RefineDet
关于RefineDet算法内容可以先看看博客:RefineDet论文笔记。
RefineDet算法是SSD算法的升级版本,所以大部分的代码也是基于SSD的开源代码来修改的。SSD开源代码参考链接:https://github.com/weiliu89/caffe/tree/ssd。RefineDet主要包含anchor refinement module (ARM) 、object detection module (ODM)、transfer connection block (TCB)3个部分,ARM部分可以直接用SSD代码,只不过将分类支路的类别数由object数量+1修改成2,类似RPN网络,目的是得到更好的初始bbox。ODM部分也可以基于SSD代码做修改,主要是原本采用的default box用ARM生成的bbox代替,剩下的分类和回归支路与SSD一样。TCB部分则通过一些卷积层和反卷积层即可实现。
在博客:RefineDet算法源码 (一)训练脚本中介绍了训练RefineDet算法的代码,其中包含宏观上的网络结构构建,并未涉及细节内容。因此这篇博客介绍RefineDet算法的具体网络结构构造细节,代码所在路径:~RefineDet/python/caffe/model_libs.py脚本的CreateRefineDetHead函数。
'''
CreateRefineDetHead函数是本文关于网络结构构造的重点,这部分代码也是在原来SSD的CreateMultiBoxHead函数
基础上修改得到的,可以看作是将原来SSD的CreateMultiBoxHead函数内容实现了两遍,一遍用来实现ARM部分,
另一边用来实现ORM部分。from_layers和from_layers2是两个重点输入,
分别对应论文中Figure1的ARM和OBM两部分输出。因此这两遍实现除了输入不同外,另一个不同是ARM部分
是类似RPN网络的bbox回归和二分类,而ORM部分是类似SSD检测网络的bbox回归和object分类。
'''
def CreateRefineDetHead(net, data_layer="data", num_classes=[], from_layers=[], from_layers2=[], normalizations=[], use_batchnorm=True, lr_mult=1, min_sizes=[], max_sizes=[], prior_variance = [0.1],aspect_ratios=[], steps=[], img_height=0, img_width=0, share_location=True, flip=True, clip=True, offset=0.5, inter_layer_depth=[], kernel_size=1, pad=0, conf_postfix='', loc_postfix='', **bn_param):
assert num_classes, "must provide num_classes"
assert num_classes > 0, "num_classes must be positive number"
if normalizations:
assert len(from_layers) == len(normalizations), "from_layers and normalizations should have same length"
assert len(from_layers) == len(min_sizes), "from_layers and min_sizes should have same length"
if max_sizes:
assert len(from_layers) == len(max_sizes), "from_layers and max_sizes should have same length"
if aspect_ratios:
assert len(from_layers) == len(aspect_ratios), "from_layers and aspect_ratios should have same length"
if steps:
assert len(from_layers) == len(steps), "from_layers and steps should have same length"
net_layers = net.keys()
assert data_layer in net_layers, "data_layer is not in net's layers"
if inter_layer_depth:
assert len(