Batch Normalization and Batch Renormalization: Detailed Derivation of the Forward and Backward Formulas
Note: CSDN's LaTeX engine seems to have trouble rendering some expressions; if needed, the same article is available on my cnblogs page: https://www.cnblogs.com/lyc-seu/p/12676505.html
1. BN Forward Pass
Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists of four formulas:
$$\mu_B=\frac{1}{m}\sum_{i=1}^m x_i\tag{1}$$
$$\delta_B^2=\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2\tag{2}$$
$$\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}}\tag{3}$$
$$y_i=\gamma\widehat{x_i}+\beta\tag{4}$$
Taking an MLP as an example, suppose the mini-batch contains $m$ samples. Then $x_i,\ i=1,2,\dots,m$ is one particular activation in some layer for the $i$-th sample. That is, when $m$ samples are fed through the network in one training step and the $i$-th sample produces $N$ activation units at layer $l$, $x_i$ denotes any one of those units. Strictly speaking, writing it as $x_i^l(n)$ would be more precise.
So BN simply computes, over one batch, the mean and variance of the $n$-th activation unit $x_i^l(n)$ of layer $l$ and standardizes it to obtain $\widehat{x_i^l(n)}$. After normalization the $m$ activations have zero mean and unit variance, which mitigates Internal Covariate Shift to some extent by reducing how the marginal distribution of each layer's activations shifts during training.
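The four formulas above map directly onto array code. A minimal NumPy sketch (the batch shape, `eps` value, and parameter initializations here are illustrative, not taken from the paper):

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """BN forward over a (m, N) batch of activations:
    each of the N units is normalized over the m samples."""
    mu = x.mean(axis=0)                    # (1) per-unit batch mean
    var = x.var(axis=0)                    # (2) per-unit batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)  # (3) standardize
    y = gamma * x_hat + beta               # (4) scale and shift
    return y, x_hat

# after normalization each unit has ~zero mean and ~unit variance,
# regardless of the input's original mean and scale
x = np.random.default_rng(0).normal(loc=5.0, scale=3.0, size=(64, 8))
y, x_hat = bn_forward(x, gamma=np.ones(8), beta=np.zeros(8))
```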
2. BN Backward Pass
- Let the upstream gradient be $\frac{\partial{L}}{\partial{y_i}}$.
- We need to compute $\frac{\partial{L}}{\partial{x_i}}$, $\frac{\partial{L}}{\partial{\gamma}}$, and $\frac{\partial{L}}{\partial{\beta}}$.
By the chain rule and formula (4):
$$\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{5}$$
Since $\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}$ contributes to $\frac{\partial{L}}{\partial{\gamma}}$ for every $i=1,2,\dots,m$, over one batch of training we define $\frac{\partial{L}}{\partial{\gamma}}$ as:
$$\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{6}$$
Similarly:
$$\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{7}$$
Computing $\frac{\partial{L}}{\partial{x_i}}$ is more involved. By the chain rule and formula (3), viewing $\widehat{x_i}$ as $g(x_i,\delta_B^2,\mu_B)$:
$$\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}\right)=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(g_1'+g_2'\frac{\partial{\delta_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}\right)\tag{8}$$
By formula (2), the second partial derivative inside the parentheses can be split further (viewing $\delta_B^2$ as $f(x_i,\mu_B)$):
$$\frac{\partial{\delta_B^2}}{\partial{x_i}}=\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\delta_B^2}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}=f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}}\tag{9}$$
Note that the two occurrences of $\frac{\partial{\delta_B^2}}{\partial{x_i}}$ in formula (9) mean different things: the left-hand side is the total derivative, while $f_1'$ is the partial derivative with $\mu_B$ held fixed. By formulas (8) and (9), once we have $f_1', f_2', g_1', g_2', g_3', \frac{\partial{\mu_B}}{\partial{x_i}}, \frac{\partial{y_i}}{\partial{\widehat{x_i}}}$, we can compute $\frac{\partial{L}}{\partial{x_i}}$.
The original paper splits formula (8) into the following terms:
$$\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{L}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{L}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}\tag{10}$$
where:
$$\frac{\partial{L}}{\partial{\widehat{x_i}}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=\frac{\partial{L}}{\partial{y_i}}\gamma\tag{10.1}$$
$$\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g_1'=\frac{1}{\sqrt{\delta_B^2+\epsilon}}\tag{10.2}$$
$$\frac{\partial{L}}{\partial{\delta_B^2}}=\frac{\partial{L}}{\partial{\widehat{x_i}}}g_2'=\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}}\longrightarrow\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}}\tag{10.3}$$
$$\frac{\partial{\delta_B^2}}{\partial{x_i}}=f_1'=\frac{2(x_i-\mu_B)}{m}\tag{10.4}$$
$$\frac{\partial{L}}{\partial{\mu_B}}=\frac{\partial{L}}{\partial{\widehat{x_i}}}g_3'+\frac{\partial{L}}{\partial{\widehat{x_i}}}g_2'f_2'\longrightarrow\sum_{i=1}^m\left(\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\delta_B^2+\epsilon}}+\frac{\partial{L}}{\partial{\delta_B^2}}\frac{2(\mu_B-x_i)}{m}\right)\tag{10.5}$$
$$\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m}\tag{10.6}$$
The full BN backward pass is therefore given by formulas (6), (7), and (10) together with (10.1)-(10.6).
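Putting formulas (6), (7), and (10.1)-(10.6) together, here is a minimal NumPy sketch of the backward pass, checked against a finite-difference gradient (the loss $L=\frac{1}{2}\sum y^2$ is chosen arbitrarily just for the check; shapes and `eps` are illustrative):

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, (x, x_hat, mu, var)

def bn_backward(dy, cache, gamma, eps=1e-5):
    """dL/dx, dL/dgamma, dL/dbeta per formulas (6), (7), (10.1)-(10.6)."""
    x, x_hat, mu, var = cache
    m = x.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)
    dgamma = np.sum(dy * x_hat, axis=0)                                  # (6)
    dbeta = np.sum(dy, axis=0)                                           # (7)
    dx_hat = dy * gamma                                                  # (10.1)
    dvar = np.sum(dx_hat * (mu - x) / 2.0 * (var + eps) ** -1.5, axis=0)  # (10.3)
    dmu = (np.sum(-dx_hat * inv_std, axis=0)
           + dvar * np.sum(2.0 * (mu - x), axis=0) / m)                  # (10.5)
    dx = dx_hat * inv_std + dvar * 2.0 * (x - mu) / m + dmu / m          # (10) with (10.2), (10.4), (10.6)
    return dx, dgamma, dbeta

# finite-difference check of dL/dx for L = 0.5 * sum(y^2), i.e. dL/dy = y
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
y, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(y, cache, gamma)
loss = lambda t: 0.5 * (bn_forward(t, gamma, beta)[0] ** 2).sum()
num_dx = np.zeros_like(x)
h = 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num_dx[idx] = (loss(xp) - loss(xm)) / (2 * h)
max_err = np.abs(dx - num_dx).max()
```

Note that the second term of (10.5) sums to zero in exact arithmetic (since $\sum_i(\mu_B-x_i)=0$), but keeping it makes the code match the paper's derivation term by term.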
3. Batch Renormalization
Following the paper "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".
Batch Renormalization is an improvement over vanilla BN: it makes the training-time and inference-time transforms equivalent, addressing the problems of non-i.i.d. data and small minibatches.
1. Forward
The formulas are similar to the original ones, with two additional non-trainable parameters $r, d$:
$$\mu_B=\frac{1}{m}\sum_{i=1}^m x_i\tag{1.1}$$
$$\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2}\tag{2.1}$$
$$\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d\tag{3.1}$$
$$y_i=\gamma\widehat{x_i}+\beta\tag{4.1}$$
$$r=\mathrm{Stop\_Gradient}\left(\mathrm{Clip}_{[1/r_{max},\,r_{max}]}\left(\frac{\sigma_B}{\sigma}\right)\right)\tag{5.1}$$
$$d=\mathrm{Stop\_Gradient}\left(\mathrm{Clip}_{[-d_{max},\,d_{max}]}\left(\frac{\mu_B-\mu}{\sigma}\right)\right)\tag{6.1}$$
Update moving averages:
$$\mu:=\mu+\alpha(\mu_B-\mu)\tag{7.1}$$
$$\sigma:=\sigma+\alpha(\sigma_B-\sigma)\tag{8.1}$$
Inference:
$$y=\gamma\frac{x-\mu}{\sigma}+\beta\tag{9.1}$$
Unlike vanilla BN, which only accumulates the moving mean and variance during training and uses them at inference time, BRN uses the moving statistics during both training and inference.
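A minimal NumPy sketch of formulas (1.1)-(9.1). Since NumPy does not track gradients, $\mathrm{Stop\_Gradient}$ is implicit, and the hyperparameter values (`r_max`, `d_max`, `alpha`) are illustrative, not prescribed by the paper:

```python
import numpy as np

def brn_forward_train(x, gamma, beta, mu, sigma,
                      r_max=3.0, d_max=5.0, alpha=0.01, eps=1e-5):
    """BRN forward (training), formulas (1.1)-(8.1).
    mu, sigma are the moving mean / moving std, updated in place."""
    mu_b = x.mean(axis=0)                                   # (1.1)
    sigma_b = np.sqrt(eps + x.var(axis=0))                  # (2.1)
    # r, d compare batch statistics to the moving statistics;
    # gradients would not flow through them (Stop_Gradient)
    r = np.clip(sigma_b / sigma, 1.0 / r_max, r_max)        # (5.1)
    d = np.clip((mu_b - mu) / sigma, -d_max, d_max)         # (6.1)
    x_hat = (x - mu_b) / sigma_b * r + d                    # (3.1)
    y = gamma * x_hat + beta                                # (4.1)
    mu += alpha * (mu_b - mu)                               # (7.1)
    sigma += alpha * (sigma_b - sigma)                      # (8.1)
    return y

def brn_forward_infer(x, gamma, beta, mu, sigma):
    return gamma * (x - mu) / sigma + beta                  # (9.1)

rng = np.random.default_rng(1)
x = rng.normal(size=(32, 4))
gamma, beta = np.ones(4), np.zeros(4)
# moving statistics initialized to the batch statistics → r = 1, d = 0,
# so the training-time output matches the inference-time output exactly
mu = x.mean(axis=0)
sigma = np.sqrt(1e-5 + x.var(axis=0))
y_inf = brn_forward_infer(x, gamma, beta, mu, sigma)
y_tr = brn_forward_train(x, gamma, beta, mu, sigma)
```

When the moving statistics agree with the batch statistics, $r\to 1$ and $d\to 0$ and the training transform (3.1) reduces to the inference transform (9.1), which is exactly the train/inference equivalence BRN aims for.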
2. Backward
The backward derivation parallels that of BN:
$$\frac{\partial{L}}{\partial{\widehat{x_i}}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=\frac{\partial{L}}{\partial{y_i}}\gamma\tag{10.11}$$
$$\frac{\partial{L}}{\partial{\sigma_B}}\longrightarrow\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r(x_i-\mu_B)}{\sigma_B^2}\tag{10.22}$$
$$\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B}\tag{10.33}$$
$$\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{r}{\sigma_B}+\frac{\partial{L}}{\partial{\sigma_B}}\frac{x_i-\mu_B}{m\sigma_B}+\frac{\partial{L}}{\partial{\mu_B}}\frac{1}{m}\tag{10.44}$$
$$\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{10.55}$$
$$\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{10.66}$$
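Formulas (10.11)-(10.66) as a NumPy sketch, with $r$ and $d$ held fixed in the backward pass (the $\mathrm{Stop\_Gradient}$ in (5.1)/(6.1)), again checked against a finite-difference gradient on an arbitrary quadratic loss:

```python
import numpy as np

def brn_forward(x, gamma, beta, r, d, eps=1e-5):
    mu_b = x.mean(axis=0)
    sigma_b = np.sqrt(eps + x.var(axis=0))      # (2.1)
    x_hat = (x - mu_b) / sigma_b * r + d        # (3.1)
    return gamma * x_hat + beta, (x, x_hat, mu_b, sigma_b)

def brn_backward(dy, cache, gamma, r):
    """dL/dx, dL/dgamma, dL/dbeta per formulas (10.11)-(10.66);
    r and d carry no gradient."""
    x, x_hat, mu_b, sigma_b = cache
    m = x.shape[0]
    dx_hat = dy * gamma                                               # (10.11)
    dsigma_b = np.sum(dx_hat * -r * (x - mu_b) / sigma_b**2, axis=0)  # (10.22)
    dmu_b = np.sum(dx_hat * -r / sigma_b, axis=0)                     # (10.33)
    dx = (dx_hat * r / sigma_b
          + dsigma_b * (x - mu_b) / (m * sigma_b)
          + dmu_b / m)                                                # (10.44)
    dgamma = np.sum(dy * x_hat, axis=0)                               # (10.55)
    dbeta = np.sum(dy, axis=0)                                        # (10.66)
    return dx, dgamma, dbeta

# finite-difference check of dL/dx for L = 0.5 * sum(y^2), i.e. dL/dy = y
rng = np.random.default_rng(2)
x = rng.normal(size=(6, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
r, d = np.array([1.2, 0.8, 1.0]), np.array([0.1, -0.2, 0.0])
y, cache = brn_forward(x, gamma, beta, r, d)
dx, dgamma, dbeta = brn_backward(y, cache, gamma, r)
loss = lambda t: 0.5 * (brn_forward(t, gamma, beta, r, d)[0] ** 2).sum()
num_dx = np.zeros_like(x)
h = 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num_dx[idx] = (loss(xp) - loss(xm)) / (2 * h)
max_err = np.abs(dx - num_dx).max()
```

Working with $\sigma_B$ instead of $\delta_B^2$ makes the BRN backward pass slightly shorter than BN's: the $(\cdot)^{-3/2}$ factor disappears, and $\partial\sigma_B/\partial\mu_B=0$ exactly, so (10.33) has no second term.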
4. BN in Convolutional Networks
The derivations above assume an MLP. For a convolutional network, the $m$ activation units in BN generalize to $m$ feature-map values. Suppose a layer's feature map after convolution is a tensor of shape $[N,H,W,C]$, where $N$ is the batch size, $H,W$ are the height and width, and $C$ is the number of channels. BN for a conv net then sets $m=N\times H\times W$: every pixel of a given channel $c$ across the batch is treated as an $x_i$, and the BN formulas above apply unchanged. Hence, for a conv net, each channel of an intermediate activation layer has its own pair of BN parameters $\gamma,\beta$.
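In NumPy terms, with the NHWC layout described above, the only change from the MLP case is reducing over axes $(0,1,2)$ instead of axis $0$; the tensor shape here is illustrative:

```python
import numpy as np

def bn_conv_forward(x, gamma, beta, eps=1e-5):
    """BN over an [N, H, W, C] tensor: statistics are taken over
    m = N*H*W values per channel; gamma, beta have shape (C,)."""
    mu = x.mean(axis=(0, 1, 2))               # per-channel mean
    var = x.var(axis=(0, 1, 2))               # per-channel variance
    x_hat = (x - mu) / np.sqrt(var + eps)     # broadcasts over N, H, W
    return gamma * x_hat + beta

x = np.random.default_rng(3).normal(size=(4, 8, 8, 16))
y = bn_conv_forward(x, gamma=np.ones(16), beta=np.zeros(16))
```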
上面的推导过程都是基于MLP的。对于卷积网络而言,BN过程中的m个激活单元被推广为m幅特征图像。 假设某一层卷积后的feature map是 [ N , H , W , C ] [N,H,W,C] [N,H,W,C]的张量,其中N表示batch数目,H,W分别表示长和宽,C表示特征通道数。则对卷积网络的BN操作时,令 m = N × H × W m = N\times H\times W m=N×H×W,也就是说将第 i i i个batch内某一通道 c c c上的任意一个特征图像素点视为 x i x_i xi,套用上面的BN公式即可。所以对于卷积网络来说,中间激活层每个通道都对应一组BN参数 γ , β \gamma,\beta γ,β.