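Previous post: SSD Algorithm Code Walkthrough (4): A First Look at the Network Structure
That post covered how the backbone of the network is assembled and how the loss function is defined. The detailed construction of the network, however, happens in two key functions of the common.py script, multi_layer_feature and multibox_layer. They add the extra layers on top of the classification network, define the prediction layers, and so on. This post therefore walks through common.py.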
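To make the prediction layers more concrete, here is a small pure-Python sketch of their channel arithmetic. It assumes the usual SSD convention, which is also how MXNet's MultiBoxPrior counts anchors: each spatial location gets len(sizes) + len(ratios) - 1 anchors, the classifier predicts num_classes + 1 scores per anchor (background included), and the regressor predicts 4 box offsets per anchor. The helper names and the example sizes, ratios, and feature-map layout are hypothetical and chosen only for illustration.

```python
def prediction_head_channels(num_classes, sizes, ratios):
    """Hypothetical helper: channel counts of the heads on one feature map.

    Assumes anchors per location = len(sizes) + len(ratios) - 1.
    """
    num_anchors = len(sizes) + len(ratios) - 1       # anchors per spatial location
    cls_channels = num_anchors * (num_classes + 1)   # +1 for the background class
    loc_channels = num_anchors * 4                   # 4 box offsets per anchor
    return num_anchors, cls_channels, loc_channels


def total_anchors(feature_sizes, anchors_per_location):
    """Hypothetical helper: total anchors over all square feature maps."""
    return sum(s * s * a for s, a in zip(feature_sizes, anchors_per_location))


if __name__ == "__main__":
    # 20 VOC classes, 2 sizes and 3 aspect ratios -> 4 anchors per location
    print(prediction_head_channels(20, sizes=[0.1, 0.141], ratios=[1, 2, 0.5]))  # (4, 84, 16)
    # Assumed SSD300-style layout: maps of 38, 19, 10, 5, 3, 1 with 4/6/6/6/4/4 anchors
    print(total_anchors([38, 19, 10, 5, 3, 1], [4, 6, 6, 6, 4, 4]))  # 8732
```

The classification and regression layers are just convolutions whose output channel counts follow these formulas, which is why the number of sizes and ratios per feature map has to be fixed before the network is built.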
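common.py contains four functions: conv_act_layer, legacy_conv_act_layer, multi_layer_feature, and multibox_layer. symbol_builder.py calls multi_layer_feature and then multibox_layer. The former builds the set of feature layers used for detection (some taken from the base network, some newly appended), and the latter builds the classification layers, the regression layers, and the anchor-generation layers. These two are the most important, so let's look at them first.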
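Before reading the code, it may help to dry-run the loop of multi_layer_feature by hand. The sketch below is plain Python with no MXNet. It mirrors the function's two branches, computes the 1x1 channel count with the same max(min_filter, num_filter // 2) rule, and tracks the spatial size with the standard convolution output formula floor((H + 2p - 3) / s) + 1 for the appended 3x3 layer. The function name plan_feature_layers, the input_size mapping, and the 300x300 VGG16 settings in the usage example (a 38x38 relu4_3, a 19x19 relu7, and the strides/pads lists) are assumptions for illustration, not values taken from this post.

```python
def plan_feature_layers(from_layers, num_filters, strides, pads,
                        input_size, min_filter=128):
    """Hypothetical dry run of multi_layer_feature's loop, without MXNet.

    input_size maps each named base-network layer to its (square) spatial size;
    it is an assumption supplied by the caller, not an argument of the real code.
    """
    assert len(from_layers) == len(num_filters) == len(strides) == len(pads)
    plan = []
    size = None
    for name, num_filter, s, p in zip(from_layers, num_filters, strides, pads):
        if name.strip():
            # Branch 1: reuse an existing base-network output as-is
            size = input_size[name.strip()]
            plan.append((name.strip(), "extract", size))
        else:
            # Branch 2: a 1x1 conv (channel reduction, size unchanged)
            # followed by a 3x3 conv with stride s and padding p
            assert plan, "extra layers need a preceding feature layer"
            num_1x1 = max(min_filter, num_filter // 2)
            size = (size + 2 * p - 3) // s + 1  # 3x3 conv output size
            plan.append((f"1x1:{num_1x1} -> 3x3:{num_filter}", "extra", size))
    return plan


if __name__ == "__main__":
    # Assumed 300x300 VGG16 settings, for illustration only
    plan = plan_feature_layers(
        from_layers=['relu4_3', 'relu7', '', '', '', ''],
        num_filters=[512, -1, 512, 256, 256, 256],
        strides=[-1, -1, 2, 2, 1, 1],
        pads=[-1, -1, 1, 1, 0, 0],
        input_size={'relu4_3': 38, 'relu7': 19},
    )
    for entry in plan:
        print(entry)
```

With these assumed settings the spatial sizes come out as 38, 19, 10, 5, 3, 1, and the 1x1 layers use 256 channels for the 512-filter layer and the 128-channel floor for the 256-filter layers.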
import mxnet as mx
import numpy as np

def multi_layer_feature(body, from_layers, num_filters, strides, pads, min_filter=128):
    """Wrapper function to extract features from base network, attaching extra
    layers and SSD specific layers

    Parameters
    ----------
    from_layers : list of str
        feature extraction layers, use '' for add extra layers
        For example:
        from_layers = ['relu4_3', 'fc7', '', '', '', '']
        which means extract feature from relu4_3 and fc7, adding 4 extra layers
        on top of fc7
    num_filters : list of int
        number of filters for extra layers, you can use -1 for extracted features,
        however, if normalization and scale is applied, the number of filter for
        that layer must be provided.
        For example:
        num_filters = [512, -1, 512, 256, 256, 256]
    strides : list of int
        strides for the 3x3 convolution appended, -1 can be used for extracted
        feature layers
    pads : list of int
        paddings for the 3x3 convolution, -1 can be used for extracted layers
    min_filter : int
        minimum number of filters used in 1x1 convolution

    Returns
    -------
    list of mx.Symbols
    """
    # arguments check
    assert len(from_layers) > 0
    assert isinstance(from_layers[0], str) and len(from_layers[0].strip()) > 0
    assert len(from_layers) == len(num_filters) == len(strides) == len(pads)

    # body is the symbol of the imported base network (VGG16 here); get_internals()
    # returns a grouped symbol exposing the output of every internal node, so any
    # intermediate layer can be picked out by name
    internals = body.get_internals()
    layers = []
    for k, params in enumerate(zip(from_layers, num_filters, strides, pads)):
        from_layer, num_filter, s, p = params
        # from_layers looks roughly like ['relu4_3', 'relu7', '', '', '', ''];
        # the trailing empty strings fall through to the else branch below
        if from_layer.strip():
            # extract from base network
            # look up the layer named by from_layer and append its output to layers
            layer = internals[from_layer.strip() + '_output']
            layers.append(layer)
        else:  # this branch appends new convolution layers on top of the original VGG16
            # attach from last feature layer
            assert len(layers) > 0
            assert num_filter > 0
            layer = layers[-1]  # the new layer takes the previous feature layer's output as input
            num_1x1 = max(min_filter, num_filter // 2)
            # two convolutions are added, a 1x1 followed by a 3x3; the 1x1 reduces the
            # number of channels (and hence the computation), and the pair is usually
            # counted as a single layer (one symbol)