理论部分结合:https://www.zhihu.com/question/38102762
batch_norm_template 函数实现
def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay):
""" Batch normalization on convolutional maps and beyond...
Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
Args:
inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC
is_training: boolean tf.Varialbe, true indicates training phase
scope: string, variable scope
moments_dims: a list of ints, indicating dimensions for moments calculation
bn_decay: float or float tensor variable, controling moving average weight
Return:
normed: batch-normalized maps
"""
with tf.variable_scope(scope) as sc:
num_channels = inputs.get_shape()[-1].value
beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
name='beta', trainable=True)
gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
name='gamma', trainable=True)
batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
decay = bn_decay if bn_decay is not None else 0.9
ema = tf.train.ExponentialMovingAverage(decay=decay)
# Operator that maintains moving averages of variables.
ema_apply_op = tf.cond(is_training,
lambda: ema.apply([batch_mean, batch_var]),
lambda: tf.no_op())
# Update moving average and return current batch's avg and var.
def mean_var_with_update():
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# ema.average returns the Variable holding the average of var.
mean, var = tf.cond(is_training,
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
return normed
分段解释:
with tf.variable_scope(scope) as sc:
num_channels = inputs.get_shape()[-1].value
beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
name='beta', trainable=True)
gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
name='gamma', trainable=True)
初始化 beta 与 gamma,这两个参数为了使每一层norm 之后的数据还能恢复之前的分布而存在,是网络在训练中学习得到的。
batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
计算每个batch的平均值和方差,moments_dims 指明需要计算的dim。
decay = bn_decay if bn_decay is not None else 0.9
ema = tf.train.ExponentialMovingAverage(decay=decay)
# Operator that maintains moving averages of variables.
ema_apply_op = tf.cond(is_training,
lambda: ema.apply([batch_mean, batch_var]),
lambda: tf.no_op())
每训练一个batch就用滑动平均值来更新整个训练集的平均值与方差,滑动平均值的原理如下:
实际运用中,衰减率decay 一般会设置为十分接近 1 的常数(0.99或0.999),或者动态可变的值。
tf.cond 函数指明,若处在训练中,则对平均值与方差更新一次滑动平均,否则什么也不干(tf.no_op())
# Update moving average and return current batch's avg and var.
def mean_var_with_update():
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
tf.control_dependencies 与 tf.identity 一起使用,进行流控制。即只有ema_apply_op 运行完才能返回更新后的滑动平均值。此处必须用 tf.identity, 否则 流控制无效,即 batch_mean 与 batch_var 不依赖 ema_apply_op,返回的便不是更新后的值。
# ema.average returns the Variable holding the average of var.
mean, var = tf.cond(is_training,
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
return normed
若是在训练中,mean 与 var 是每次batch更新后的值,若不在训练中,则用最后更新完的值亦即整个训练集的平均值与方差。
batch_norm_template 函数结束
BN层是对于每个神经元做归一化处理,那么在全连接层的BN就应该这样实现:
def batch_norm_for_fc(inputs, is_training, bn_decay, scope):
""" Batch normalization on FC data.
Args:
inputs: Tensor, 2D BxC input
is_training: boolean tf.Varialbe, true indicates training phase
bn_decay: float or float tensor variable, controling moving average weight
scope: string, variable scope
Return:
normed: batch-normalized maps
"""
return batch_norm_template(inputs, is_training, scope, [0,], bn_decay)
在卷积层上的BN使用,其实也是使用了类似权值共享的策略,把一整张特征图(即每个通道)当做一个神经元进行处理。
def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope):
""" Batch normalization on 2D convolutional maps.
Args:
inputs: Tensor, 4D BHWC input maps
is_training: boolean tf.Varialbe, true indicates training phase
bn_decay: float or float tensor variable, controling moving average weight
scope: string, variable scope
Return:
normed: batch-normalized maps
"""
return batch_norm_template(inputs, is_training, scope, [0,1,2], bn_decay)