BatchNorm Algorithm Explained
1 The Principle of BatchNorm
BatchNorm standardizes the data in each input mini-batch, which makes the distribution of the network's inputs more stable.
During training, if the input distribution of a layer changes a lot from one iteration to the next, the data jitters heavily, the weights swing heavily in response, and the network struggles to converge. Batch norm normalizes the data and reduces this jitter between batches, which speeds up training and accelerates convergence.
BatchNorm computation flow
Input: let a mini-batch be $\mathcal{B}=\{x_{1...m}\}$, with learnable parameters $\gamma, \beta$.
First, compute the mean of $\mathcal{B}$:
$$\mu_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}x_i$$
Then compute the variance of $\mathcal{B}$:
$$\sigma^2_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}(x_i - \mu_\mathcal{B})^2$$
Normalize the data:
$$\hat{x_i} \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}}$$
Here, $\epsilon$ prevents errors caused by a variance of 0; $\epsilon$ is set to 1e-5.
Finally, apply a scale and shift to the normalized data:
$$y_i \leftarrow \gamma \hat{x_i} + \beta$$
where $\gamma, \beta$ are learned during training.
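To make the four steps concrete, here is a minimal NumPy sketch of a single forward pass on a tiny batch; the array `x` and its shape are made-up illustration values, and `gamma`/`beta` are fixed here although they would be learned in practice. It verifies that the normalized data has roughly zero mean and unit variance per feature.

import numpy as np

# A made-up mini-batch: m = 4 samples, 3 features (illustration values only)
x = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 6.0],
              [3.0, 6.0, 9.0],
              [4.0, 8.0, 12.0]])
gamma = np.ones(3)   # scale; learnable in practice
beta = np.zeros(3)   # shift; learnable in practice
eps = 1e-5

mu = x.mean(axis=0)                      # per-feature mean of the batch
var = x.var(axis=0)                      # per-feature (uncorrected) variance
x_hat = (x - mu) / np.sqrt(var + eps)    # normalize
y = gamma * x_hat + beta                 # scale and shift

print(x_hat.mean(axis=0))  # ~0 for every feature
print(x_hat.std(axis=0))   # ~1 for every feature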
2 BatchNorm Code Implementation
import numpy as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.ones(D, dtype=x.dtype))

    cache = None
    if mode == 'train':
        # Per-feature batch statistics
        sample_mean = x.mean(axis=0)
        sample_var = x.var(axis=0)

        # Update the running averages with an exponential moving average
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var

        # Normalize, then scale and shift
        std = np.sqrt(sample_var + eps)
        x_centered = x - sample_mean
        x_norm = x_centered / std
        out = gamma * x_norm + beta

        cache = (x_norm, x_centered, std, gamma)
    elif mode == 'test':
        # Use the running statistics accumulated during training
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_norm + beta
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
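A minimal usage sketch of the function above (the batch data, shapes, and loop length are made-up illustration values): `bn_param` carries the running statistics across training calls, and the same dictionary is then reused at test time.

import numpy as np

np.random.seed(0)
gamma = np.ones(5)
beta = np.zeros(5)
bn_param = {'mode': 'train'}  # running_mean / running_var are created on the first call

# A few training steps: each call updates bn_param['running_mean'] and ['running_var']
for _ in range(10):
    x = np.random.randn(32, 5) * 2.0 + 1.0   # made-up batch of shape (N, D)
    out, cache = batchnorm_forward(x, gamma, beta, bn_param)

# At test time the stored running statistics are used instead of batch statistics
bn_param['mode'] = 'test'
x_test = np.random.randn(8, 5) * 2.0 + 1.0
out_test, _ = batchnorm_forward(x_test, gamma, beta, bn_param)
print(out_test.mean(axis=0))  # roughly 0, since the running stats match the data distribution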
3 Why Use a Moving Average
At the start of training we cannot obtain the mean and variance of the entire training set.
Even if we ran a complete forward pass over the whole training set before training and recorded the mean and variance, those statistics would change as soon as the model parameters change. That is why we use a moving average to estimate the mean and variance of the whole training set.
4 The Moving Average in BN
Every batch during training performs one moving-average update:
Initial values: moving_mean = 0, moving_var = 1, which corresponds to a standard normal distribution (in principle they could be initialized to any value); momentum = 0.9.
moving_mean -= (moving_mean - batch_mean) * (1 - momentum)
moving_var -= (moving_var - batch_var) * (1 - momentum)
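Written out, moving_mean -= (moving_mean - batch_mean) * (1 - momentum) is algebraically the same update as moving_mean = momentum * moving_mean + (1 - momentum) * sample_mean used in the code above. A small sketch (the batch data and loop length are made-up illustration values) showing the moving averages drifting from their initial values toward the statistics of the data actually seen:

import numpy as np

momentum = 0.9
moving_mean, moving_var = 0.0, 1.0   # initial values, as described above

np.random.seed(0)
for _ in range(200):
    batch = np.random.randn(64) * 3.0 + 5.0   # made-up data: true mean 5, true variance 9
    batch_mean, batch_var = batch.mean(), batch.var()
    # The two in-place updates from above
    moving_mean -= (moving_mean - batch_mean) * (1 - momentum)
    moving_var -= (moving_var - batch_var) * (1 - momentum)

print(moving_mean, moving_var)  # approaches ~5 and ~9, the dataset mean and variance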