batch_normalization对于训练深度学习模型有非常重大的意义。但是这里有三个函数可供选择:
- tf.nn.batch_normalization
- tf.layers.batch_normalization
- tf.contrib.layers.batch_norm
其中第一种是最底层的实现,可以更灵活的对bn做对应的修改,剩下两种则是高级封装(对于数据的均值和方差采取滑动平均的方法来计算)。
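我们都知道,在训练过程中需要把第2/3个函数的关键字training(tf.contrib.layers.batch_norm中对应的关键字是is_training)设置为True,而在测试时候则需要设置为False。然而我们发现,如果仅这样操作,则得不到很好的准确率。这是因为第2/3个函数中滑动均值和滑动方差分别是从0和1初始化而来,然后采用滑动平均的方法逐渐更新。但是这两个量是不可训练的变量(trainable=False),优化器不会更新它们;它们的更新操作只是被放进了tf.GraphKeys.UPDATE_OPS集合,默认并不会随训练节点一起执行。如果不显式执行这些更新操作,保存到模型里的均值和方差就一直停留在初始值。所以需要在定义训练节点时加入下面这段代码: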
# Fetch the ops registered in the UPDATE_OPS collection (the batch-norm
# moving mean / moving variance updates described above).
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Make the training step depend on those ops so that they are executed every
# time train_op is run; otherwise the moving statistics keep their initial values.
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
这样,滑动均值和方差才会在每一步训练中被更新,保存模型时保存下来的才是更新后的值。(否则,在测试时均值和方差仍然是初始值0和1)。
查看滑动平均的均值和方差的方法:
# Restore a checkpoint and print every global variable whose name contains
# "moving" (i.e. the moving statistics kept by the batch-norm layers).
all_variable = tf.global_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    for variable in all_variable:
        if "moving" in variable.name:
            # eval() uses the session opened by the enclosing `with` block.
            print(variable.name, variable.eval())
当然,如果想手写带有滑动平均的batch_norm可以这样:
def bn_layer(x, scope, is_training, epsilon=0.001, decay=0.99, reuse=None):
    """
    Performs a batch normalization layer that keeps its own moving statistics.

    Args:
        x: input tensor; statistics are computed over every axis except the
            last one, whose size must be statically known
        scope: scope name under which the layer variables are created
        is_training: python boolean value, evaluated while the graph is built.
            True normalizes with the batch statistics and updates the moving
            averages; False normalizes with the stored moving averages
        epsilon: the variance epsilon - a small float number to avoid dividing by 0
        decay: the moving average decay
        reuse: forwarded to tf.variable_scope to control variable reuse
    Returns:
        The output tensor of the batch normalization layer
    """
    with tf.variable_scope(scope, reuse=reuse):
        shape = x.get_shape().as_list()
        num_channels = shape[-1]

        # gamma: a trainable scale factor
        gamma = tf.get_variable("gamma", [num_channels],
                                initializer=tf.constant_initializer(1.0), trainable=True)
        # beta: a trainable shift value
        beta = tf.get_variable("beta", [num_channels],
                               initializer=tf.constant_initializer(0.0), trainable=True)
        # Moving statistics: not trainable, only changed by the assign ops below.
        moving_avg = tf.get_variable("moving_avg", [num_channels],
                                     initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", [num_channels],
                                     initializer=tf.constant_initializer(1.0), trainable=False)

        if is_training:
            # Reduce over every axis but the last; the result already has
            # shape [num_channels], so no keep_dims/reshape round trip is needed.
            reduce_axes = list(range(len(shape) - 1))
            avg, var = tf.nn.moments(x, reduce_axes)
            update_moving_avg = tf.assign(moving_avg, moving_avg * decay + avg * (1 - decay))
            update_moving_var = tf.assign(moving_var, moving_var * decay + var * (1 - decay))
            control_inputs = [update_moving_avg, update_moving_var]
        else:
            avg = moving_avg
            var = moving_var
            control_inputs = []

        # Tie the moving-average updates to the output so that they run
        # whenever the normalized tensor is evaluated in training mode.
        with tf.control_dependencies(control_inputs):
            output = tf.nn.batch_normalization(x, avg, var, offset=beta, scale=gamma,
                                               variance_epsilon=epsilon)
    return output
def bn_layer_top(x, scope, is_training, epsilon=0.001, decay=0.99):
    """
    Returns a batch normalization layer that automatically switches between
    train and test phases based on the tensor is_training

    Args:
        x: input tensor
        scope: scope name
        is_training: boolean tensor or variable, used as the tf.cond predicate
        epsilon: epsilon parameter - see bn_layer
        decay: decay parameter - see bn_layer
    Returns:
        The correct batch normalization layer based on the value of is_training
    """
    # Both branches are built under the same scope: the training branch is
    # passed reuse=None, the inference branch reuse=True so that it picks up
    # the same gamma/beta/moving statistics.
    return tf.cond(
        is_training,
        lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=True, reuse=None),
        lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=False, reuse=True),
    )
或者用更简洁的
import numpy as np
import tensorflow as tf
from tensorflow.python import control_flow_ops
def batch_norm(x, n_out, phase_train, scope='bn', decay=0.99, epsilon=1e-3):
    """
    Batch normalization on convolutional maps.

    Args:
        x: Tensor, 4D input maps; moments are taken over axes 0, 1 and 2,
            so the channels are expected on the last axis
        n_out: integer, depth of input maps
        phase_train: boolean tf.Tensor or tf.Variable, true indicates training phase
        scope: string, variable scope
        decay: float, decay of the exponential moving average of the batch statistics
        epsilon: float, variance epsilon passed to tf.nn.batch_normalization
    Return:
        normed: batch-normalized maps
    """
    with tf.variable_scope(scope):
        beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
                            name='gamma', trainable=True)
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # Training branch: update the moving averages, then hand back the
            # batch statistics. tf.identity under control_dependencies makes
            # the returned tensors depend on the update op.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Inference branch reads the averaged statistics instead.
        # NOTE(review): ema.average() needs ema.apply() to have been called
        # already, i.e. this relies on tf.cond building the training branch
        # first - confirm for the TensorFlow version in use.
        mean, var = tf.cond(phase_train,
                            mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
    return normed
下面附上普通的bn(不带滑动平均,只用当前batch的均值和方差做归一化):
def batchnorm(inputs):
    """
    Plain batch normalization that uses only the statistics of the current batch.

    No moving averages are kept: mean and variance are always recomputed
    from `inputs` over axes 0, 1 and 2.

    Args:
        inputs: tensor with at least 4 axes whose axis 3 holds the channels
            (assumes a 4D input with channels last - TODO confirm with callers)
    Returns:
        The normalized tensor, scaled by "scale" and shifted by "offset"
    """
    with tf.variable_scope("batchnorm"):
        # this block looks like it has 3 inputs on the graph unless we do this
        inputs = tf.identity(inputs)

        channels = inputs.get_shape()[3]
        offset = tf.get_variable("offset", [channels], dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        # Scale starts close to 1 (normal distribution, mean 1.0, stddev 0.02).
        scale = tf.get_variable("scale", [channels], dtype=tf.float32,
                                initializer=tf.random_normal_initializer(1.0, 0.02))
        # Per-channel statistics; the reduced axes are dropped by default.
        mean, variance = tf.nn.moments(inputs, axes=[0, 1, 2])
        variance_epsilon = 1e-5
        normalized = tf.nn.batch_normalization(inputs, mean, variance, offset, scale,
                                               variance_epsilon=variance_epsilon)
        return normalized