1.Batch normalization的公式
其中: x i x_{i} xi是输入, μ B \mu _{B} μB是均值, σ B 2 \sigma _{B}^{2} σB2是方差, γ γ γ是缩放系数(scale), β β β是偏移(offset)系数, ε \varepsilon ε是方差偏移系数, B N ( x i ) BN(x_{i}) BN(xi)是输出。
2. Batch normalization介绍
批标准化(batch normalization,BN),一般用在激活函数之前,使结果
y
=
w
x
+
b
y=wx+b
y=wx+b,各个维度参数均值为0,方差为1。通过规范化让激活函数的输入分布在线性区间,让每一层的输入有一个稳定的分布会有利于网络的训练。
优点:
- 加大探索步长,加快收敛速度。
- 更容易跳出局部极小。
- 破坏原来的数据分布,一定程度上防止过拟合。
- 解决收敛速度慢和梯度爆炸。
3. Batch normalization的tensorflow API
3.1
mean, variance = tf.nn.moments(x, axes, name=None, keep_dims=False)
计算统计矩,mean 是一阶矩即均值,variance 则是二阶中心矩即方差,axes=[0]表示按列计算;
3.2
tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
tf.nn.batch_norm_with_global_normalization(x, mean, variance, beta, gamma, variance_epsilon, scale_after_normalization, name=None);
tf.nn.moments 计算返回的 mean 和 variance 作为 tf.nn.batch_normalization 参数调用;