[Study Notes] Tensorflow-ENet Code Walkthrough (Part 2)

This post walks through how the ENet model is implemented in TensorFlow: the parameter setup, the use of arg_scope, and a detailed look at ENet's building blocks, including the initial block and the bottleneck module, covering key operations such as downsampling, dilated convolution, and asymmetric convolution. Reading through the code gives a deeper understanding of ENet's network structure and how it works.

Continuing from the previous post: once the input images have been preprocessed, we can feed them into the ENet network for training.

        #Create the model inference
        with slim.arg_scope(ENet_arg_scope(weight_decay=weight_decay)): 
            logits, probabilities = ENet(images, 
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True, 
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

Here, slim.arg_scope is used to set default argument values for a set of functions. ENet_arg_scope(weight_decay=weight_decay) builds such a scope with the weight decay set to the value we defined, and the with statement activates it for everything constructed inside. ENet_arg_scope itself looks like this:

def ENet_arg_scope(weight_decay=2e-4,
                   batch_norm_decay=0.1,
                   batch_norm_epsilon=0.001):
  '''
  The arg scope for enet model. The weight decay is 2e-4 as seen in the paper.
  Batch_norm decay is 0.1 (momentum 0.1) according to official implementation.

  INPUTS:
  - weight_decay(float): the weight decay for weights variables in conv2d and separable conv2d
  - batch_norm_decay(float): decay for the moving average of batch_norm momentums.
  - batch_norm_epsilon(float): small float added to variance to avoid dividing by zero.

  OUTPUTS:
  - scope(arg_scope): a tf-slim arg_scope with the parameters needed for ENet.
  '''
  # Set weight_decay for weights in conv2d and separable_conv2d layers.
  with slim.arg_scope([slim.conv2d], # decorate slim.conv2d with slim.arg_scope to set its default arguments
                      weights_regularizer=slim.l2_regularizer(weight_decay), # default weights_regularizer: l2 regularization
                      biases_regularizer=slim.l2_regularizer(weight_decay)): # default biases_regularizer: l2 regularization

    # Set parameters for batch_norm.
    with slim.arg_scope([slim.batch_norm], # likewise, set batch_norm's default parameters
                        decay=batch_norm_decay,
                        epsilon=batch_norm_epsilon) as scope:
      return scope

This function nests further slim.arg_scope calls to set defaults for slim.conv2d: its weight and bias regularizers become L2 regularizers whose strength is our weight-decay value. It also sets the decay and epsilon defaults of slim.batch_norm. To summarize: the outer slim.arg_scope activates the scope returned by ENet_arg_scope, and ENet_arg_scope internally calls slim.arg_scope to set the defaults of slim.conv2d and slim.batch_norm.
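
To make the mechanism concrete, here is a minimal standalone sketch of arg_scope semantics (assuming TF 1.x with tf.contrib.slim; the tensor shapes and default values are made up for illustration):

import tensorflow as tf
slim = tf.contrib.slim

# Inside the scope, every slim.conv2d call inherits these defaults
# unless it overrides them explicitly.
with slim.arg_scope([slim.conv2d], stride=2, activation_fn=None):
    x = tf.placeholder(tf.float32, [None, 64, 64, 3])
    a = slim.conv2d(x, 8, [3, 3])            # inherits stride=2, no activation
    b = slim.conv2d(x, 8, [3, 3], stride=1)  # an explicit argument wins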

With ENet's defaults in place, we can feed the images into ENet for training.

            logits, probabilities = ENet(images, # the input training images
                                         num_classes,
                                         batch_size=batch_size,
                                         is_training=True, # run PReLU and batch normalization in training mode
                                         reuse=None,
                                         num_initial_blocks=num_initial_blocks,
                                         stage_two_repeat=stage_two_repeat,
                                         skip_connections=skip_connections)

ENet's arguments are: the input images, the number of classes, the batch size, is_training (whether PReLU and batch normalization run in training mode), reuse, the number of initial blocks, the number of times the stage-two bottlenecks are repeated, and whether to use skip connections. The reuse flag follows tf.variable_scope semantics: pass True to share the variables of an already-built ENet instead of creating new ones.

Now let's look at ENet's network structure to get an overall picture (the screenshot is from the original ENet paper):
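
For reference, the architecture table in the paper reads roughly as follows (output sizes for an example 512×512 input, as in the paper):

initial → 16×256×256
stage 1: bottleneck1.0 (downsampling) + 4 × regular bottleneck → 64×128×128
stage 2: bottleneck2.0 (downsampling), then regular, dilated 2, asymmetric 5, dilated 4, regular, dilated 8, asymmetric 5, dilated 16 → 128×64×64
stage 3: same as stage 2 but without the downsampling bottleneck → 128×64×64
stage 4: bottleneck4.0 (upsampling) + 2 × regular bottleneck → 64×128×128
stage 5: bottleneck5.0 (upsampling) + regular bottleneck → 16×256×256
fullconv → C×512×512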

Besides the initial block and the fullconv layer, there are five major stages in total, and stage 3 is identical to stage 2 except that it lacks the downsampling layer. In ENet, the initial block is defined as follows:

After the image comes in, one branch convolves it with 13 3×3 kernels at stride 2, while the other branch max-pools it with a 2×2 window at stride 2. Because the input is a 3-channel image, the pooled branch keeps 3 channels; concatenating the two outputs yields a 16-channel feature map.

The bottleneck module is inspired by ResNet. Three convolution variants are defined: regular convolution, dilated convolution, and asymmetric convolution. The overall structure of the bottleneck module is as follows:
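
As a stand-in for the paper's figure, the following minimal sketch shows the shared pattern (assuming TF 1.x slim; the function name and simplified structure are illustrative, not the repo's exact bottleneck implementation): a 1×1 projection that reduces the channel depth, a main convolution that the three variants swap out, and a 1×1 expansion whose output is added back to the input as in ResNet.

import tensorflow as tf
slim = tf.contrib.slim

def bottleneck_sketch(inputs, output_depth, projection_ratio=4, scope='bottleneck'):
    # Illustrative only: the repo's real bottleneck() also handles
    # downsampling/upsampling, PReLU, batch norm and spatial dropout.
    reduced_depth = output_depth // projection_ratio

    # 1x1 projection to reduce channel depth
    net = slim.conv2d(inputs, reduced_depth, [1, 1], scope=scope + '_proj')

    # main convolution; the dilated variant adds a dilation rate, and the
    # asymmetric variant replaces this with a [5,1] conv followed by a [1,5] conv
    net = slim.conv2d(net, reduced_depth, [3, 3], scope=scope + '_conv')

    # 1x1 expansion back to the output depth
    net = slim.conv2d(net, output_depth, [1, 1], scope=scope + '_expand')

    # residual connection (assumes inputs already have output_depth channels)
    return tf.add(net, inputs, name=scope + '_add')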

Let's go through each of these modules alongside the code.

The initial block code is as follows:

def initial_block(inputs, is_training=True, scope='initial_block'):
    '''
    The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.

    The conv branch has 13 layers, while the maxpool branch gives 3 layers corresponding to the RGB channels.
    Both output layers are then concatenated to give an output of 16 layers.

    NOTE: Does not need to store pooling indices since it won't be used later for the final upsampling.

    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]

    OUTPUTS:
    - net_concatenated(Tensor): a 4D Tensor that contains the concatenated outputs of the two branches.
    '''
    #Convolutional branch
    net_conv = slim.conv2d(inputs, 13, [3,3], stride=2, activation_fn=None, scope=scope+'_conv') # 3x3 conv, 13 filters, stride 2
    net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm')
    net_conv = prelu(net_conv, scope=scope+'_prelu')

    #Max pool branch
    net_pool = slim.max_pool2d(inputs, [2,2], stride=2, scope=scope+'_max_pool')

    #Concatenated output - does it matter max pool comes first or conv comes first? probably not.
    net_concatenated = tf.concat([net_conv, net_pool], axis=3, name=scope+'_concat')
    return net_concatenated

In the convolution branch, the input is first convolved with 13 3×3 kernels at stride 2; the result then goes through a batch normalization layer and finally through a PReLU activation.
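
A quick shape check confirms the channel arithmetic (a hypothetical dummy input; the 360×480 size is just illustrative, and initial_block and prelu are the functions defined in this post):

import tensorflow as tf

# 13 conv channels + 3 pooled RGB channels = 16 output channels,
# and stride 2 halves the spatial dimensions.
images = tf.placeholder(tf.float32, [8, 360, 480, 3])
net = initial_block(images, is_training=True)
print(net.get_shape())  # (8, 180, 240, 16)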

The PReLU function code is:

def prelu(x, scope, decoder=False):
    '''
    Performs the parametric relu operation. This implementation is based on:
    https://stackoverflow.com/questions/39975676/how-to-implement-prelu-activation-in-tensorflow

    For the decoder portion, prelu becomes just a normal relu.

    INPUTS:
    - x(Tensor): a 4D Tensor that undergoes prelu
    - scope(str): the string to name your prelu operation's alpha variable.
    - decoder(bool): if True, prelu becomes a normal relu.

    OUTPUTS:
    - pos + neg / x (Tensor): the prelu output, or a plain relu output in decoder mode.
    '''
    # NOTE: the body below is reconstructed following the StackOverflow answer linked above.
    if decoder:
        return tf.nn.relu(x, name=scope)

    # one learnable alpha per channel, initialized to zero
    alpha = tf.get_variable(scope + 'alpha', x.get_shape()[-1],
                            initializer=tf.constant_initializer(0.0),
                            dtype=tf.float32)
    pos = tf.nn.relu(x)
    neg = alpha * (x - abs(x)) * 0.5
    return pos + neg
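
Note how the last three lines implement the PReLU formula f(x) = max(0, x) + alpha * min(0, x): for x >= 0 the term alpha * (x - |x|) * 0.5 is zero, and for x < 0 it equals alpha * x, so pos + neg reproduces the piecewise definition exactly.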