摘要
本文探讨 Batch Normalization 在测试或推断时使用的算法及其原理.
相关
配套代码, 请参考文章 :
Python和PyTorch对比实现批标准化 Batch Normalization 函数在测试或推理过程中的算法.
系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
正文
Batch Normalization 在训练时使用每批数据的均值和方差进行数据的规范化, 在测试或推断的时候使用全体数据的特征.
我们不可能在训练前再次遍历全体数据的特征, 耗时太大. 也不可能记录所有批次的中间结果, 内存消耗太大.
目前主流的深度学习框架 TensorFlow 和 PyTorch 等采用的是参数估计加滑动平均法, 并引入的一个超参数来解决这个问题.
接下来我们详细探讨这个方法的理论基础.
1. 分步使用样本特征计算总体特征
1.1 分步使用样本均值计算总体均值
已知一个 k k k 维数组 x x x 的均值为 μ \mu μ, 标准差为 σ \sigma σ, 排列成列数量不相等的矩阵 X X X , 共 m m m 行, 每行 n i n_i ni 个元素, 下标 i i i 表示第 i i i 行.
其形式类似如下 :
x 11 x_{11} x11 | x 12 x_{12} x12 | x 13 x_{13} x13 | x 14 x_{14} x14 | ( n 1 = 4 ) (n_1 = 4) (n1=4) | |
---|---|---|---|---|---|
x 21 x_{21} x21 | x 22 x_{22} x22 | x 23 x_{23} x23 | ( n 2 = 3 ) (n_2 = 3) (n2=3) | ||
x 31 x_{31} x31 | x 32 x_{32} x32 | x 33 x_{33} x33 | x 34 x_{34} x34 | x 35 x_{35} x35 | ( n 3 = 5 ) (n_3 = 5) (n3=5) |
x 41 x_{41} x41 | x 42 x_{42} x42 | ( n 4 = 2 ) (n_4 = 2) (n4=2) | |||
x 51 x_{51} x51 | x 52 x_{52} x52 | x 53 x_{53} x53 | x 54 x_{54} x54 | ( n 5 = 3 ) (n_5 = 3) (n5=3) |
设第 i i i 组的均值为 x ˉ i \bar x_i xˉi, 标准差为 s i s_i si. 求 x ˉ \bar x xˉ 与 μ \mu μ 的关系, 求 s s s 与 σ \sigma σ 的关系.
由题意, 知矩阵的元素的总数量 k k k , 均值 μ \mu μ, 方差 σ 2 \sigma^2 σ2 :
k = ∑ i = 1 m n i    μ = 1 k ∑ i = 1 m ∑ j = 1 n i x i j    σ 2 = 1 k ∑ i = 1 m ∑ j = 1 n i ( x i j − μ ) 2 k = \sum_{i=1}^{m}n_i \;\\ \mu =\frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}\\ \;\\ \sigma^2 = \frac{1}{k}\sum_{i=1}^{m}\sum_{j=1}^{n_i}(x_{ij}-\mu)^2 k=i=1∑mniμ=k1i=1∑mj=1∑nixijσ2=k1i=1∑mj=1∑ni(xij−μ)2
因为 :
∑ i = 1 m ∑ j = 1 n i x i j = ∑ j = 1 n n i x ˉ i \sum_{i=1}^{m}\sum_{j=1}^{n_i}x_{ij}=\sum_{j=1}^{n}n_i\bar x_i\\ i=1∑mj=1∑nixij=j=1∑nnixˉi
μ = 1 k ∑ j = 1 n n i x ˉ i \mu=\frac{1}{k}\sum_{j=1}^{n}n_i\bar x_i μ=k1j=1∑nnixˉi
令 n = 2 n = 2 n=2 :
μ = n 1 x ˉ 1 + n 2 x ˉ 2 ( n 1 + n 2 ) \mu=\frac{n_1\bar x_1 +n_2\bar x_2}{(n_1+n_2)} μ=(n1+n2)n1xˉ1+n2xˉ2
改成迭代式 :
μ i + 1 = k i μ i + n i + 1 x ˉ i + 1 ( k i + n i + 1 ) \mu_{i+1}=\frac{k_i\mu_i +n_{i+1}\bar x_{i+1}}{(k_i+n_{i+1})} μi+1=(ki+ni+1)kiμi