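Continuing from the previous post: once the input images have been preprocessed, they can be fed into the ENet model for training.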
# Create the model inference
with slim.arg_scope(ENet_arg_scope(weight_decay=weight_decay)):
    logits, probabilities = ENet(images,
                                 num_classes,
                                 batch_size=batch_size,
                                 is_training=True,
                                 reuse=None,
                                 num_initial_blocks=num_initial_blocks,
                                 stage_two_repeat=stage_two_repeat,
                                 skip_connections=skip_connections)
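Here slim.arg_scope is a context manager that sets default values for selected arguments of the listed slim functions, so every call made inside the with block picks up those defaults. ENet_arg_scope(weight_decay=weight_decay) is an ordinary function call that passes in our own weight_decay value and returns such a scope, which the outer slim.arg_scope then re-enters. ENet_arg_scope itself looks like this: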
def ENet_arg_scope(weight_decay=2e-4,
                   batch_norm_decay=0.1,
                   batch_norm_epsilon=0.001):
    '''
    The arg scope for enet model. The weight decay is 2e-4 as seen in the paper.
    Batch_norm decay is 0.1 (momentum 0.1) according to official implementation.
    INPUTS:
    - weight_decay(float): the weight decay for weights variables in conv2d and separable conv2d
    - batch_norm_decay(float): decay for the moving average of batch_norm momentums.
    - batch_norm_epsilon(float): small float added to variance to avoid dividing by zero.
    OUTPUTS:
    - scope(arg_scope): a tf-slim arg_scope with the parameters needed for ENet.
    '''
    # Set weight_decay for weights in conv2d and separable_conv2d layers.
    with slim.arg_scope([slim.conv2d],  # set default arguments for slim.conv2d
                        weights_regularizer=slim.l2_regularizer(weight_decay),  # L2 regularization on the conv2d weights
                        biases_regularizer=slim.l2_regularizer(weight_decay)):  # L2 regularization on the conv2d biases
        # Set parameters for batch_norm.
        with slim.arg_scope([slim.batch_norm],  # likewise, set default arguments for batch_norm
                            decay=batch_norm_decay,
                            epsilon=batch_norm_epsilon) as scope:
            return scope
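This function nests slim.arg_scope twice. The outer scope sets the default weights_regularizer and biases_regularizer of slim.conv2d to L2 regularization with strength weight_decay; the inner scope sets the default decay and epsilon of slim.batch_norm. To summarize: ENet_arg_scope is called with our weight_decay, builds the defaults for slim.conv2d and slim.batch_norm through the nested slim.arg_scope calls, and returns the resulting scope; the slim.arg_scope in the training script then applies that scope to every layer created while ENet is being built.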
With these defaults in place, the images can be fed into ENet for training.
logits, probabilities = ENet(images,  # input training images
                             num_classes,
                             batch_size=batch_size,
                             is_training=True,  # training-mode flag for PReLU and batch normalization
                             reuse=None,
                             num_initial_blocks=num_initial_blocks,
                             stage_two_repeat=stage_two_repeat,
                             skip_connections=skip_connections)
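The arguments to ENet are, in order: the input images, the number of classes, batch_size, a training-mode flag for PReLU and batch normalization, reuse (I do not fully understand this one yet; it is presumably the usual TensorFlow variable-reuse flag), the number of initial blocks in ENet, the number of stage-two sections in ENet (stage_two_repeat), and whether skip connections are used.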
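Now let's look at ENet's network structure, starting with the overall picture (the screenshot is from the original ENet paper):
Apart from the initial block and the fullconv layer, the network consists of five stages. Stage 3 is identical to stage 2 except that it has no downsampling layer. In ENet, the initial block is defined as follows: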
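The input image goes through two branches. One branch convolves it with 13 filters of size 3×3 at stride 2; the other applies 2×2 pooling at stride 2. Because the input is a 3-channel image, the pooled output has 3 channels. Concatenating the two outputs gives 13 + 3 = 16 channels.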
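The shape arithmetic of the initial block can be checked with a few lines of plain Python. This is only a sketch under stated assumptions: the helper name initial_block_output_shape is hypothetical, the 512×512 input size is just an example value, and the height and width are assumed to be even so that the stride-2 convolution and the stride-2 pooling produce the same spatial size and can be concatenated.

```python
def initial_block_output_shape(height, width, in_channels=3, conv_filters=13, stride=2):
    """Return (conv_shape, pool_shape, concat_shape) as (H, W, C) tuples.

    Assumes even spatial dimensions, so both stride-2 branches halve them exactly.
    """
    if height % stride or width % stride:
        raise ValueError("this sketch assumes height and width divisible by the stride")
    out_h, out_w = height // stride, width // stride
    conv_shape = (out_h, out_w, conv_filters)   # conv branch: 13 feature maps
    pool_shape = (out_h, out_w, in_channels)    # pooling keeps the 3 input channels
    concat_shape = (out_h, out_w, conv_filters + in_channels)  # 13 + 3 = 16
    return conv_shape, pool_shape, concat_shape


conv_shape, pool_shape, concat_shape = initial_block_output_shape(512, 512)
print(conv_shape, pool_shape, concat_shape)
```

Pooling never changes the channel count, which is why the number of convolution filters is 13 rather than 16: the three RGB channels from the pooling branch make up the rest.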
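The bottleneck module is inspired by ResNet and comes in three variants: regular convolution, dilated convolution, and asymmetric convolution. The general structure of a bottleneck module is:
The sections below go through these modules alongside the code.
The code for the initial block is: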
def initial_block(inputs, is_training=True, scope='initial_block'):
    '''
    The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.
    The conv branch has 13 layers, while the maxpool branch gives 3 layers corresponding to the RGB channels.
    Both output layers are then concatenated to give an output of 16 layers.
    NOTE: Does not need to store pooling indices since it won't be used later for the final upsampling.
    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]
    OUTPUTS:
    - net_concatenated(Tensor): a 4D Tensor that contains the
    '''
    # Convolutional branch
    net_conv = slim.conv2d(inputs, 13, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')  # 3x3 convolution, 13 filters, stride 2
    net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm')
    net_conv = prelu(net_conv, scope=scope+'_prelu')
    # Max pool branch
    net_pool = slim.max_pool2d(inputs, [2,2], stride=2, scope=scope+'_max_pool')
    # Concatenated output - does it matter max pool comes first or conv comes first? probably not.
    net_concatenated = tf.concat([net_conv, net_pool], axis=3, name=scope+'_concat')
    return net_concatenated
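In the convolution branch, the input is first convolved with 13 filters of size 3×3 at stride 2, the result is normalized by a batch normalization layer, and PReLU is then applied as the activation function.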
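As a reminder of what PReLU computes, here is a small scalar sketch in plain Python. It follows the standard definition f(x) = max(0, x) + alpha * min(0, x), where alpha is a learned slope for negative inputs; the function names and the alpha value of 0.25 are illustrative assumptions, not taken from the ENet code. The second function shows an equivalent form that uses abs instead of min, which is convenient when only elementwise tensor ops are available.

```python
def prelu_scalar(x, alpha):
    """PReLU by definition: identity for x > 0, slope alpha for x < 0."""
    return max(0.0, x) + alpha * min(0.0, x)


def prelu_scalar_abs_form(x, alpha):
    """Equivalent form: (x - |x|) / 2 equals min(0, x)."""
    pos = max(0.0, x)                  # the ReLU part
    neg = alpha * (x - abs(x)) * 0.5   # zero for x >= 0, alpha * x for x < 0
    return pos + neg


alpha = 0.25  # example slope; in a network this is a trainable variable
samples = [-2.0, -0.5, 0.0, 3.0]
by_definition = [prelu_scalar(x, alpha) for x in samples]
by_abs_form = [prelu_scalar_abs_form(x, alpha) for x in samples]
print(by_definition)
```

With alpha fixed at 0 the negative part vanishes and the function reduces to an ordinary ReLU.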
The prelu function used above is defined as:
def prelu(x, scope, decoder=False):
    '''
    Performs the parametric relu operation. This implementation is based on:
    https://stackoverflow.com/questions/39975676/how-to-implement-prelu-activation-in-tensorflow
    For the decoder portion, prelu becomes just a normal relu
    INPUTS:
    - x(Tensor): a 4D Tensor that undergoes prelu
    - scope(str): the string to name your prelu operation's alpha variable.
    - decoder(bool): if True, prelu becomes a normal relu.
    OUTPUTS:
    - pos + neg / x (Tensor): gives prelu output only during training;