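x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
Here x is the input data, mean is the mean of the batch x, and variance is the variance of the batch x (note that the mean and variance are computed separately for each dimension). beta and gamma are the learnable shift and scale parameters, respectively. BN_EPSILON is a small constant added to the variance so that the division is still well-defined when the variance is 0 (it is usually set to 0.001).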
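To make the per-dimension statistics concrete, here is a minimal plain-Python sketch of the arithmetic described above: gamma * (x - mean) / sqrt(variance + epsilon) + beta. It is only an illustration, not the TensorFlow implementation; the helper name batch_norm_2d and the sample batch are made up for this example.

```python
import math

BN_EPSILON = 0.001  # same value as quoted in the text


def batch_norm_2d(batch, beta, gamma, eps=BN_EPSILON):
    """Normalize a [batch, features] list of lists, one feature column at a time."""
    n = len(batch)
    num_features = len(batch[0])
    # Per-dimension mean and (biased) variance, computed over the batch axis.
    mean = [sum(row[j] for row in batch) / n for j in range(num_features)]
    variance = [
        sum((row[j] - mean[j]) ** 2 for row in batch) / n
        for j in range(num_features)
    ]
    # y = gamma * (x - mean) / sqrt(variance + eps) + beta
    normalized = [
        [
            gamma[j] * (row[j] - mean[j]) / math.sqrt(variance[j] + eps) + beta[j]
            for j in range(num_features)
        ]
        for row in batch
    ]
    return normalized, mean, variance


# Hypothetical batch: the second feature is constant, so its variance is 0.
batch = [[1.0, 10.0], [3.0, 10.0]]
out, mean, variance = batch_norm_2d(batch, beta=[0.0, 0.0], gamma=[1.0, 1.0])
```

The second column shows why the epsilon matters: its variance is exactly 0, and without the added constant the normalization would divide by zero.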
The complete bn function is as follows:
# TensorFlow 1.x-style graph code.
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

# BN_DECAY, BN_EPSILON, UPDATE_OPS_COLLECTION and _get_variable are defined elsewhere.
def bn(x, use_bn, is_training):
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]  # one parameter per channel (last dimension)
    if not use_bn:
        # Without batch normalization, only add a learnable bias.
        bias = _get_variable('bias', params_shape, initializer=tf.zeros_initializer)
        return x + bias
    # Reduce over every axis except the last one.
    axis = list(range(len(x_shape) - 1))
    beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer)
    moving_mean = _get_variable('moving_mean', params_shape, initializer=tf.zeros_initializer, trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape, initializer=tf.ones_initializer, trainable=False)
    # these ops will only be performed when training.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
    # Batch statistics while training, moving statistics otherwise.
    mean, variance = control_flow_ops.cond(is_training, lambda: (mean, variance), lambda: (moving_mean, moving_variance))
    x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
    return x
1. A brief introduction to the functions used above:
tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False) computes the mean and variance of x. When x has shape [batch, height, width, depth]: for global normalization, axes = [0, 1, 2], and mean and variance have shape [depth]; for simple batch normalization, axes = [0], and mean and variance have shape [height, width, depth].
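The relationship between axes and the shape of the resulting statistics can be sketched in plain Python. This is only a shape calculation, not a call into TensorFlow; the function name moments_output_shape and the example shape are hypothetical.

```python
def moments_output_shape(input_shape, axes, keep_dims=False):
    """Shape of the mean/variance obtained by reducing `input_shape` over `axes`."""
    if keep_dims:
        # Reduced axes are kept with length 1.
        return [1 if i in axes else d for i, d in enumerate(input_shape)]
    # Reduced axes are dropped.
    return [d for i, d in enumerate(input_shape) if i not in axes]


x_shape = [32, 8, 8, 16]  # hypothetical [batch, height, width, depth]

# Global normalization: reduce over batch, height and width.
global_shape = moments_output_shape(x_shape, axes=[0, 1, 2])

# Simple batch normalization: reduce over the batch axis only.
simple_shape = moments_output_shape(x_shape, axes=[0])

# The bn function builds the same axes as the global case: all but the last.
bn_axes = list(range(len(x_shape) - 1))
```

Because bn reduces over every axis except the last, its beta, gamma, and moving statistics all have one entry per channel.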
moving_averages.assign_moving_average(variable, value, decay, zero_debias=True, name=None) maintains a moving average of a variable: after the update, the variable holds variable * decay + value * (1 - decay). In this example, decay = BN_DECAY = 0.9997.
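The update rule can be tried out on plain floats. The sketch below reuses the name assign_moving_average for a stand-in that applies only the formula quoted above (no zero-debiasing, no TensorFlow variables); the constant batch mean of 5.0 is an invented value.

```python
BN_DECAY = 0.9997


def assign_moving_average(variable, value, decay):
    """Plain-float stand-in: return variable * decay + value * (1 - decay)."""
    return variable * decay + value * (1 - decay)


moving_mean = 0.0      # moving_mean starts at zero, as in the bn function
for _ in range(1000):  # pretend every batch has mean 5.0
    moving_mean = assign_moving_average(moving_mean, 5.0, BN_DECAY)
```

With a decay this close to 1, the average moves slowly: after 1000 updates in this toy run it has covered only roughly a quarter of the distance from 0 to 5.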
tf.add_to_collection(name, value) adds value to a collection; the collection is identified by the key name.
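control_flow_ops.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None) returns the result of true_fn if pred is true, and the result of false_fn otherwise. In this example it controls which mean and variance are used for training and for testing. Note that pred must not be a Python bool; it has to be a tensor.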
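The selection logic itself is small enough to mimic in plain Python. The sketch below is a stand-in, not the TensorFlow op: it takes an ordinary Python bool, which the real cond does not accept, and the helper names and statistics are made up for illustration.

```python
def cond(pred, true_fn, false_fn):
    """Plain-Python stand-in: call exactly one of the two branches."""
    return true_fn() if pred else false_fn()


def select_statistics(is_training, batch_stats, moving_stats):
    """Pick (mean, variance) the way the bn function does."""
    return cond(is_training, lambda: batch_stats, lambda: moving_stats)


batch_stats = (2.0, 1.0)   # made-up statistics of the current batch
moving_stats = (1.8, 1.2)  # made-up accumulated moving statistics

# Training: normalize with the statistics of the current batch.
train_mean, train_var = select_statistics(True, batch_stats, moving_stats)

# Testing: normalize with the moving averages accumulated during training.
test_mean, test_var = select_statistics(False, batch_stats, moving_stats)
```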
2. Next, a brief look at how the BN parameters are updated:
batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
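tf.get_collection(key, scope=None): returns everything stored in the collection with the key key. If scope is not None, only the items in that collection whose name matches scope are returned.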
tf.group(*inputs, **kwargs): groups several operations or variables into a single operation.
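The collect-then-group pattern can be imitated with a small plain-Python registry. Everything in this sketch is a hypothetical stand-in (the NamedOp class, the dictionary-backed collections, the key 'bn_update_ops', and the op names); it only mirrors the behaviour described above and does not touch a TensorFlow graph.

```python
import re

_collections = {}  # hypothetical stand-in for the graph's collections
executed = []      # records which ops have been run


class NamedOp:
    """Toy operation that has a name and records when it runs."""

    def __init__(self, name):
        self.name = name

    def run(self):
        executed.append(self.name)


def add_to_collection(name, value):
    _collections.setdefault(name, []).append(value)


def get_collection(key, scope=None):
    items = list(_collections.get(key, []))
    if scope is None:
        return items
    # Keep only the items whose name matches `scope` from the start.
    return [item for item in items if re.match(scope, item.name)]


def group(*ops):
    """Bundle several ops into one callable that runs them all."""
    def run_all():
        for op in ops:
            op.run()
    return run_all


UPDATE_OPS_COLLECTION = 'bn_update_ops'  # hypothetical key
add_to_collection(UPDATE_OPS_COLLECTION, NamedOp('layer1/moving_mean_update'))
add_to_collection(UPDATE_OPS_COLLECTION, NamedOp('layer1/moving_variance_update'))
add_to_collection(UPDATE_OPS_COLLECTION, NamedOp('layer2/moving_mean_update'))

all_updates = get_collection(UPDATE_OPS_COLLECTION)
layer1_updates = get_collection(UPDATE_OPS_COLLECTION, scope='layer1')

batchnorm_updates_op = group(*all_updates)
batchnorm_updates_op()  # one call triggers every collected update
```

Grouping the updates with the gradient step, as train_op does above, means that a single training run also refreshes the moving mean and variance.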