# Reposted from: https://blog.csdn.net/hanyong4719/article/details/80558995
def batch_norm_layer(x, train_phase, scope_bn):
    """Batch-normalize `x` (TensorFlow 1.x graph-mode implementation).

    Statistics are computed over every axis except the last (channel)
    axis. During training the batch statistics are used and folded into
    exponential moving averages; at inference the moving averages are
    used instead.

    Args:
        x: Input tensor; the last dimension is treated as channels.
        train_phase: Scalar boolean tensor (e.g. ``tf.placeholder(tf.bool)``
            or ``tf.Variable(True/False)``) selecting training vs. inference.
        scope_bn: Name for the ``tf.variable_scope`` holding beta/gamma.

    Returns:
        A tensor with the same shape as `x`, batch-normalized.
    """
    with tf.variable_scope(scope_bn):
        # Learnable per-channel shift (beta, init 0) and scale (gamma, init 1).
        beta = tf.Variable(tf.constant(0.0, shape=[x.shape[-1]]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[x.shape[-1]]),
                            name='gamma', trainable=True)
        # Reduce over all axes except the channel axis, e.g. [0, 1, 2] for NHWC.
        axes = list(range(len(x.shape) - 1))
        batch_mean, batch_var = tf.nn.moments(x, axes, name='moments')
        # NOTE(review): decay=0.5 forgets history very quickly; typical BN
        # implementations use 0.99-0.999 — confirm this is intentional.
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # Update the moving averages as a side effect, then return the
            # current batch statistics (tf.identity forces the dependency).
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Training branch: batch stats (updating the EMAs).
        # Inference branch: the accumulated moving averages.
        mean, var = tf.cond(train_phase,
                            mean_var_with_update,
                            lambda: (ema.average(batch_mean),
                                     ema.average(batch_var)))
        # Epsilon 1e-3 guards against division by a near-zero variance.
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
        return normed