RefineDet算法源码 -训练代码解析

个人微信公众号:AI研习图书馆,欢迎关注~

深度学习知识及资源分享,学习交流,共同进步~

RefineDet算法系列文章

相关链接

  1. 论文地址
  2. 项目地址
  3. RefineDet文章翻译
  4. RefineDet训练代码解析
  5. RefineDet网络结构解析
  6. 目标检测-RefineDet实现详细步骤
  7. RefineDet论文总结
  8. 目标检测-RefineDet训练脚本解析
  9. 目标检测-RefineDet算法检测部分网络解析
  10. RefineDet:C++测试代码

1. 引言

论文:Single-Shot Refinement Neural Network for Object Detection

论文链接:https://arxiv.org/abs/1711.06897

代码链接:https://github.com/sfzhang15/RefineDet

2. RefineDet算法训练脚本解析

关于RefineDet算法总结内容可以移步我的其它博客:论文笔记栏目。

2.1 介绍

RefineDet算法是SSD算法的升级版本,所以大部分的代码也是基于SSD的开源代码来修改的。

SSD开源代码参考链接:
https://github.com/weiliu89/caffe/tree/ssd

RefineDet主要包含anchor refinement module (ARM) 、object detection module (ODM)、transfer connection block (TCB)3个部分。

ARM部分可以直接用SSD代码,只不过将分类支路的类别数由object数量+1修改成2,类似RPN网络,目的是得到更好的初始bbox。

ODM部分也可以基于SSD代码做修改,主要是原本采用的default box用ARM生成的bbox代替,剩下的分类和回归支路与SSD一样。

TCB部分则通过一些卷积层和反卷积层即可实现。

这篇博客介绍RefineDet算法的训练和宏观上的网络结构构造,以主网络为ResNet101,数据集为COCO的训练脚本为例。

代码所在位置:~RefineDet/examples/refinedet/ResNet101_COCO_320.py

2.2 代码具体解析
from __future__ import print_function
import sys
sys.path.append("./python")
import caffe
from caffe.model_libs import *
from google.protobuf import text_format

import math
import os
import shutil
import stat
import subprocess

# Add extra layers on top of a "base" network (e.g. VGGNet or ResNet).
# AddExtraLayers函数是网络结构构造中比较重要的函数,主要实现的就是论文中的
# transfer connection block (TCB)部分,也就是类似FPN算法的特征融合操作。
def AddExtraLayers(net, arm_source_layers=[], use_batchnorm=True):
    use_relu = True

    # Add additional convolutional layers.
    # 320/32: 10 x 10
# 传进来的net变量就是ResNet101(最后一层是'res5c_relu')
    last_layer = net.keys()[-1]
    from_layer = last_layer

    # 320/64: 5 x 5
# 在ResNet101后面再添加一个residual block,为的是获得stride=64的feature map。
    ResBody(net, from_layer, '6', out2a=128, out2b=128, out2c=512, stride=2, use_branch1=True)


# arm_source_layers传进来是['res3b3_relu', 'res4b22_relu', 'res5c_relu', 
# 'res6_relu'],reverse的目的是从最后一层开始执行transfer connection block操作。
    arm_source_layers.reverse()
    num_p = 6
    for index, layer in enumerate(arm_source_layers):
        from_layer = layer
        out_layer = "TL{}_{}".format(num_p, 1)
        ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 3, 1, 1)

# num_p等于6是最后一层的处理,比较特殊,因为没有更上层的反卷积输入,最后得到的输出用P6表示。
        if num_p == 6:
            from_layer = out_layer
            out_layer = "TL{}_{}".format(num_p, 2)
            ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 3, 1, 1)

            from_layer = out_layer
            out_layer = "P{}".format(num_p)
            ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 3, 1, 1)
# 其他情况下执行的就是文章中的transfer connection block操作,
# 包含两个卷积层、一个反卷积操作和一个求和操作,最后得到的输出用P5、P4、P3表示。
        else:
            from_layer = out_layer
            out_layer = "TL{}_{}".format(num_p, 2)
            ConvBNLayer(net, from_layer, out_layer, use_batchnorm, False, 256, 3, 1, 1)

            from_layer = "P{}".format(num_p+1)
            out_layer = "P{}-up".format(num_p+1)
            DeconvBNLayer(net, from_layer, out_layer, use_batchnorm, False, 256, 2, 0, 2)

            from_layer = ["TL{}_{}".format(num_p, 2), "P{}-up".format(num_p+1)]
            out_layer = "Elt{}".format(num_p)
            EltwiseLayer(net, from_layer, out_layer)
            relu_name = '{}_relu'.format(out_layer)
            net[relu_name] = L.ReLU(net[out_layer], in_place=True)
            out_layer = relu_name

            from_layer = out_layer
            out_layer = "P{}".format(num_p)
            ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 3, 1, 1)
        num_p = num_p - 1

    return net

### Modify the following parameters accordingly ###
# The directory which contains the caffe code.
# We assume you are running the script at the CAFFE_ROOT.
caffe_root = os.getcwd()

# Set true if you want to start training right after generating all files.
run_soon = True
# Set true if you want to load from most recently saved snapshot.
# Otherwise, we will load from the pretrain_model defined below.
resume_training = True
# If true, Remove old model files.
remove_old_models = False

# The database file for training data. Created by data/coco/create_data.sh
train_data = "examples/coco/coco_train_lmdb"
# The database file for testing data. Created by data/coco/create_data.sh
test_data = "examples/coco/coco_minival_lmdb"
<
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI研习图书馆

您的鼓励将是我创作的最大动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值