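TensorFlow 中 batch normalization 的实现主要有两种：一种是 tf.layers.batch_normalization，另一种是 tf.nn.batch_normalization。这里使用的是第二种。需要注意的是，tf.nn.batch_normalization 只负责用给定的均值、方差以及 gamma/beta 做归一化，均值和方差的计算以及推理时所用的滑动平均都需要自己实现（即下方代码中的 tf.nn.moments 与 tf.train.ExponentialMovingAverage）。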
原理此处不再赘述。
代码如下：
def BN(input, isTraining=False, name='BatchNorm', moving_decay=0.9, eps=1e-4):
#print('BN isTraining: ', isTraining)
_in = input
shape = input.get_shape().as_list()
assert len(shape) in [2, 4]
with tf.variable_scope(name):
gamma = tf.Variable(tf.constant(1.0, dtype=tf.float32, shape=[shape[-1]]), name='gamma')
beta = tf.Variable(tf.constant(0.0, dtype=tf.float32, shape=[shape[-1]]), name='beta')
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(input, axes=axes)
ema = tf.train.ExponentialMovingAverage(decay=moving_decay)
def mean_var_with_update():
ema_apply_op = ema