Batch Normalization || 2. 在tensorflow 中的API

前言:

Batch Normalization || 1. 原理介绍
Batch Normalization || 2. 在tensorflow 中的API
Batch Normalization || 3. 在tensorflow 中bn的坑 —— Is_training、momentum的设置

1 BN在tensorflow中的api

tensorflow中batch normalization的实现主要有下面三个:

  • tf.nn.batch_normalization
  • tf.layers.batch_normalization
  • tf.contrib.layers.batch_norm

    封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。

1.1 tf.nn.batch_normalization(最底层的实现)

tf.nn.batch_normalization(
   x,
   mean,
   variance,
   offset,
   scale,
  variance_epsilon,
   name=None
)

该函数是一种最底层的实现方法,在实际使用时mean、variance、scale、offset等参数需要自己传递并更新,因此实际使用时,还需要自己对该函数进行封装。一般不推荐使用,但是对了解batch_norm的原理很有帮助。

封装使用的例子

import tensorflow as tf

def batch_norm(x, name_scope, training, epslilon=1e-3, decay=0.99):
   with tf.variable_scope(name_scope):
       size = x.get_shape().as_list()[-1]
       scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(0.1))
       offset = tf.get_variable('offset', [size])
       moving_mean = tf.get_variable('moving_mean', [size], initializer=tf.zeros_initializer, trainable=False)
       moving_var = tf.get_variable('moving_var', [size], initializer=tf.ones_initializer(), trainable=False)

       batch_mean, batch_var = tf.nn.moments(x, list(range(len(x.get_shape())-1)))
       train_mean_op = tf.assign(moving_mean, moving_mean * decay + batch_mean * (1 - decay))
       train_var_op = tf.assign(moving_var, moving_var * decay + batch_var * (1-decay))

       def is_training():
           with tf.control_dependencies([train_mean_op, train_var_op]):
               mean, var = batch_mean, batch_var
               return tf.nn.batch_normalization(x, mean, var, offset, scale, epslilon)
       def is_test():
           mean, var = moving_mean, moving_var
           return tf.nn.batch_normalization(x, mean, var, offset, scale, epslilon)
       return tf.cond(training, is_training(), is_test())
       


在batch_norm中,首先计算了x在通道上的mean和var,然后将moving_mean和moving_var进行更新,并根据是训练阶段还是测试阶段选择当前批次的mean和var还是统计的mean和var作为tf.nn.batch_normalization的scale和offset。

注意:

  • 保存网络模型时,直接保存所有的变量即可。保存的模型中bn,已经保存了统计后的滑动平均和滑动方差
  • 其中tf.assign()和tf.control_denpendencies()的配合使用,后面讲解

1.2 tf.layers.batch_normalization

tf.layers.batch_normalization(
   inputs,
   axis=-1,
   momentum=0.99,
   epsilon=0.001,
   center=True,
   scale=True,
   beta_initializer=tf.zeros_initializer(),
   gamma_initializer=tf.ones_initializer(),
   moving_mean_initializer=tf.zeros_initializer(),
   moving_variance_initializer=tf.ones_initializer(),
   beta_regularizer=None,
   gamma_regularizer=None,
   beta_constraint=None,
   gamma_constraint=None,
   training=False,
   trainable=True,
   name=None,
   reuse=None,
   renorm=False,
   renorm_clipping=None,
   renorm_momentum=0.99,
   fused=None,
   virtual_batch_size=None,
   adjustment=None
)


我们实际使用的时候:

x_norm = tf.layers.batch_normalization(x, training=training)
# ...

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
   train_op = optimizer.minimize(loss)

这个操作的添加,是为了每次训练中,进行更新滑动平均和滑动方差。否则最终保存的bn参数是没有进行更新的。

注意

  • 对于update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)后面细讲。

1.3 tf.contrib.layers.batch_norm

下面的两个bn在源码上的定义是同一个函数

import tensorflow.contrib.layers as layers
layers.batch_norm

import tensorflow.contrib.slim as slim
slim.batch_norm


进入函数内部,可以看到参数的传入:

tf.contrib.layers.batch_norm(
   inputs,
   decay=0.999,
   center=True,
   scale=False,
   epsilon=0.001,
   activation_fn=None,
   param_initializers=None,
   param_regularizers=None,
   updates_collections=tf.GraphKeys.UPDATE_OPS,
   is_training=True,
   reuse=None,
   variables_collections=None,
   outputs_collections=None,
   trainable=True,
   batch_weights=None,
   fused=None,
   data_format=DATA_FORMAT_NHWC,
   zero_debias_moving_mean=False,
   scope=None,
   renorm=False,
   renorm_clipping=None,
   renorm_decay=0.99,
   adjustment=None
)
......


实际使用时,训练脚本需要添加:

   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
   with tf.control_dependencies(update_ops):
     train_op = optimizer.minimize(loss)

  • 在使用方面上,注重看到,在网络设置优化器训练时,需要加上tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops)的操作,与tf.layers.batch_normalization一样。
  • 原理方面上,在offset的设置上,tf.layers.batch_normalization和tf.contrib.layers.batch_norm都为True;scale的设置上,前者为True后者为False。也就是tf.contrib.layers.batch_norm中,默认不对处理后的input进行线性缩放,只施加的偏移。
    (有相关的说明和测试,bn中的scale是可有可无的。不会影响训练的效果。所以再slim中,bn中也就没有添加scale参数进行学习)

2 其他

在3.3中,遗留的问题

2.1 tf.control_dependencies

import tensorflow as tf
a_1 = tf.Variable(1)
b_1 = tf.Variable(2)
update_op = tf.assign(a_1, 10)
add = tf.add(a_1, b_1)
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   print(sess.run(add))        # 3
   upda = sess.run(update_op)  # 运行过后,update_op的值等于a_1的值   
   print(upda, sess.run(a_1))  # 10, 10
   print(sess.run(add))        # 12

a_2 = tf.Variable(1)
b_2 = tf.Variable(2)
update_op = tf.assign(a_2, 10)
with tf.control_dependencies([update_op]):
   add_with_dependencies = tf.add(a_2, b_2)  # 执行add_with_dependencies前一定先执行update_op
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   print(sess.run(add_with_dependencies)  # 12


注意:

  • tf.control_dependencies
  • tf.assign(a, b) 返回是一个操作,该操作是将b赋给a。

在正常操作中,是不会经过update_op操作的。但是tf.assign配合使用tf.control_dependencies函数,就是在运行update_op的情况下再进行add的操作(相当于做了一个限定)


2.2 tf.GraphKeys.UPDATE_OPS

举例子说明:

import tensorflow as tf

is_traing = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=is_traing)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
# with tf.control_dependencies(update_ops):
   # train_op = optimizer.minimize(loss)

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   saver = tf.train.Saver()
   saver.save(sess, "batch_norm_layer/Model")
   
输出:
[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(3,) dtype=float32_ref>, 
<tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(3,) dtype=float32_ref>]


可以看到输出的即为两个batch_normalization中更新mean和var的操作,需要保证它们在train_op前完成。

  • tf.GraphKeys.UPDATE_OPS:
    tensorflow的计算图中内置的一个集合,其中会保存一些需要训练操作之前完成的操作,比如:bn中更新后的mean和var。
    这两个操作会自动被加入tf.GraphKeys.UPDATE_OPS这个集合。
  • tf.get_collection作用:把tf.GraphKeys.UPDATE_OPS这个集合添加到当前图中。
  • with tf.control_dependencies(update_ops):在进行网络优化前,保证bn中的mean和var的更新

3 代码中使用

3.1 训练时

(已经明确bn中:training=True,会使用batch内的mean和var。training=False,会使用更新后的滑动均值和滑动方差)
代码中设置

  • 输入参数training=True

  • 参数的更新:计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

    
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    
    
  • 参数的保存:保存模型时不能只保存训练参数,bn中的均值方差并不是训练参数,但是必须同时保存。所以解决方法

    • 直接保存所有参数,
    • 额外保存参数(BN均值方差)
    # 方法(1)
    var_list = tf.global_variables()
    saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
    
    # 方法(2)
    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
    
    

    建议第一种,因为训练过程中我们会用到其他不可训练的参数。比如,在设置优化器时
    train_op = optimizer.minimize(total_loss, global_step, colocate_gradients_with_ops=True)
    当中的global_step也是不可训练参数,也需要保存下来。当训练中断后,可以很好的记录训练的迭代次数,让网络的学习率等正确更新。


3.2 预测时

测试时需要注意一点,输入参数training=False,其他正常导入模型,正常测试即可

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值