Batch-Normalization有三种定义格式,第一种格式是低级版本,需要先计算均值和方差。后面的两种是封装后的,可以直接使用,下面分别介绍:
1、tf.nn.batch_normalization
这个函数实现batch_normalization需要两步,分装程度较低,一般不使用
(1)tf.nn.moments(x, axes, name=None, keep_dims=False) ⇒ mean, variance:
统计矩,mean 是一阶矩,variance 则是二阶中心矩
(2)tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
由函数接口可知,tf.nn.moments 计算返回的 mean 和 variance 作为 tf.nn.batch_normalization 参 数进一步调用;
如我们需计算的 tensor 的 shape 为一个四元组 [batch_size, height, width, kernels],一个示例程序如下:
import tensorflow as tf
shape = [128, 32, 32, 64]
a = tf.Variable(tf.random_normal(shape)) # a:activations
axis = list(range(len(shape)-1)) # len(x.get_shape())
a_mean, a_var = tf.nn.moments(a, axis)
得到a_mean和a_var以后可以进入第二步计算:
tf.nn.batch_normalization(
x, #输入
mean = a_mean,
variance = a_var,
offset, #tensor,偏移量
scale, # tensor,尺度缩放值
variance_epsilon, #避免除0
name=None)
完整的实现