Detailed Derivation of the Forward and Backward Formulas for Batch Normalization and Batch Renormalization

Note: CSDN's LaTeX engine seems to have some issues and a few expressions may fail to render; if needed, see the same article on my cnblogs blog: https://www.cnblogs.com/lyc-seu/p/12676505.html

I. BN Forward Propagation

Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists of four main formulas:

$$\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1}$$

$$\delta_B^2=\frac{1}{m}\sum_{i=1}^m (x_i-\mu_B)^2 \tag{2}$$

$$\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}} \tag{3}$$

$$y_i=\gamma\widehat{x_i}+\beta \tag{4}$$

Taking an MLP as an example, suppose a mini-batch contains $m$ samples. Then $x_i,\ i=1,2,\dots,m$ here denotes one particular activation in some layer for the $i$-th sample. That is, if $m$ samples are fed in for one training step and the $i$-th sample produces $N$ activation units at layer $l$, then $x_i$ stands for any one of those units. Strictly speaking, it would be more explicit to write it as $x_i^l(n)$.

So BN in fact takes the $n$-th activation unit $x_i^l(n)$ of layer $l$, computes its mean and variance over a batch, and standardizes it to obtain $\widehat{x_i^l(n)}$. The $m$ normalized activations then have zero mean and unit variance, which to some extent removes Internal Covariate Shift by reducing the change in the marginal distribution of each layer's activations over the training samples.
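To make formulas (1)–(4) concrete, here is a minimal NumPy sketch of the forward pass for one mini-batch (the helper name `bn_forward` and its interface are illustrative, not from the paper); `x` has shape `(m, N)`, one row per sample, so statistics are taken over the batch axis for each activation unit:

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """BN forward, formulas (1)-(4); x has shape (m, N)."""
    mu = x.mean(axis=0)                    # (1): per-unit batch mean
    var = x.var(axis=0)                    # (2): per-unit batch variance (biased, 1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # (3): standardize
    y = gamma * x_hat + beta               # (4): scale and shift
    cache = (x_hat, gamma, mu, var, eps)   # intermediates saved for the backward pass
    return y, cache
```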

II. BN Backward Propagation

  • Let the incoming gradient from the layer above be $\frac{\partial L}{\partial y_i}$.
  • We need to compute $\frac{\partial L}{\partial x_i}$, $\frac{\partial L}{\partial \gamma}$, and $\frac{\partial L}{\partial \beta}$.

By the chain rule and formula (4):

$$\frac{\partial L}{\partial \gamma}=\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \gamma}=\frac{\partial L}{\partial y_i}\widehat{x_i} \tag{5}$$

Since $\frac{\partial L}{\partial y_i}\widehat{x_i}$ contributes to $\frac{\partial L}{\partial \gamma}$ for every $i=1,2,\dots,m$, over one batch $\frac{\partial L}{\partial \gamma}$ is defined as:

$$\frac{\partial L}{\partial \gamma}=\sum_{i=1}^m \frac{\partial L}{\partial y_i}\widehat{x_i} \tag{6}$$

Similarly:

$$\frac{\partial L}{\partial \beta}=\sum_{i=1}^m \frac{\partial L}{\partial y_i} \tag{7}$$

Computing $\frac{\partial L}{\partial x_i}$, however, is more involved. By the chain rule and formula (3), viewing $\widehat{x_i}$ as $g(x_i,\delta_B^2,\mu_B)$, we have:

$$\frac{\partial L}{\partial x_i}=\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \widehat{x_i}}\left(\frac{\partial \widehat{x_i}}{\partial x_i}+\frac{\partial \widehat{x_i}}{\partial \delta_B^2}\frac{\partial \delta_B^2}{\partial x_i}+\frac{\partial \widehat{x_i}}{\partial \mu_B}\frac{\partial \mu_B}{\partial x_i}\right)=\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \widehat{x_i}}\left(g_1'+g_2'\frac{\partial \delta_B^2}{\partial x_i}+g_3'\frac{\partial \mu_B}{\partial x_i}\right) \tag{8}$$

By formula (2), the second partial derivative inside the parentheses can be decomposed further, viewing $\delta_B^2$ as $f(x_i,\mu_B)$:

$$\frac{\partial \delta_B^2}{\partial x_i}=\frac{\partial \delta_B^2}{\partial x_i}+\frac{\partial \delta_B^2}{\partial \mu_B}\frac{\partial \mu_B}{\partial x_i}=f_1'+f_2'\frac{\partial \mu_B}{\partial x_i} \tag{9}$$

Note that the two occurrences of $\frac{\partial \delta_B^2}{\partial x_i}$ in formula (9) mean different things: the left-hand side is the total derivative, while the first term on the right is the partial derivative with $\mu_B$ held fixed. From formulas (8) and (9), once we obtain $f_1', f_2', g_1', g_2', g_3', \frac{\partial \mu_B}{\partial x_i}$, and $\frac{\partial y_i}{\partial \widehat{x_i}}$, we can compute $\frac{\partial L}{\partial x_i}$.

The original paper decomposes formula (8) into the following terms:

$$\frac{\partial L}{\partial x_i}=\frac{\partial L}{\partial \widehat{x_i}}\frac{\partial \widehat{x_i}}{\partial x_i}+\frac{\partial L}{\partial \delta_B^2}\frac{\partial \delta_B^2}{\partial x_i}+\frac{\partial L}{\partial \mu_B}\frac{\partial \mu_B}{\partial x_i} \tag{10}$$

where:
$$\frac{\partial L}{\partial \widehat{x_i}}=\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \widehat{x_i}}=\frac{\partial L}{\partial y_i}\,\gamma \tag{10.1}$$

$$\frac{\partial \widehat{x_i}}{\partial x_i}=g_1'=\frac{1}{\sqrt{\delta_B^2+\epsilon}} \tag{10.2}$$

$$\frac{\partial L}{\partial \delta_B^2}=\frac{\partial L}{\partial \widehat{x_i}}\,g_2'=\frac{\partial L}{\partial \widehat{x_i}}\,\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}}\longrightarrow\sum_{i=1}^m \frac{\partial L}{\partial \widehat{x_i}}\,\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \tag{10.3}$$

$$\frac{\partial \delta_B^2}{\partial x_i}=f_1'=\frac{2(x_i-\mu_B)}{m} \tag{10.4}$$

$$\frac{\partial L}{\partial \mu_B}=\frac{\partial L}{\partial \widehat{x_i}}\,g_3'+\frac{\partial L}{\partial \widehat{x_i}}\,g_2'f_2'\longrightarrow\sum_{i=1}^m\left(\frac{\partial L}{\partial \widehat{x_i}}\,\frac{-1}{\sqrt{\delta_B^2+\epsilon}}+\frac{\partial L}{\partial \delta_B^2}\,\frac{2(\mu_B-x_i)}{m}\right) \tag{10.5}$$

$$\frac{\partial \mu_B}{\partial x_i}=\frac{1}{m} \tag{10.6}$$

The complete BN backward pass is thus given by formulas (6), (7), and (10).
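Collecting (6), (7), and (10.1)–(10.6) gives the following NumPy sketch of the backward pass (again an illustrative helper, paired with the `cache` returned by `bn_forward` above); `dy` is $\partial L/\partial y$ with shape `(m, N)`:

```python
def bn_backward(dy, cache):
    """BN backward, formulas (6), (7), (10.1)-(10.6); dy has shape (m, N)."""
    x_hat, gamma, mu, var, eps = cache
    m = dy.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)
    x_mu = x_hat / inv_std                                     # recover x_i - mu_B

    dgamma = np.sum(dy * x_hat, axis=0)                        # (6)
    dbeta = np.sum(dy, axis=0)                                 # (7)
    dx_hat = dy * gamma                                        # (10.1)
    dvar = np.sum(dx_hat * (-x_mu) / 2.0 * inv_std**3, axis=0)            # (10.3)
    dmu = np.sum(-dx_hat * inv_std, axis=0) \
          + dvar * np.sum(-2.0 * x_mu, axis=0) / m                        # (10.5)
    dx = dx_hat * inv_std + dvar * 2.0 * x_mu / m + dmu / m    # (10) with (10.2)/(10.4)/(10.6)
    return dx, dgamma, dbeta
```

Checking `dx` against a finite-difference gradient is a quick way to validate the derivation.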

III. Batch Renormalization

See the paper "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".

Batch Renormalization is an improvement on standard BN. It makes training and inference behave equivalently, addressing the problems of non-i.i.d. data and small mini-batches.

1. Forward

The formulas are similar to BN's, with two additional non-trainable parameters $r$ and $d$:
$$\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1.1}$$

$$\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_{i=1}^m (x_i-\mu_B)^2} \tag{2.1}$$

$$\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d \tag{3.1}$$

$$y_i=\gamma\widehat{x_i}+\beta \tag{4.1}$$

$$r=\mathrm{Stop\_Gradient}\left(\mathrm{Clip}_{[1/r_{max},\,r_{max}]}\left(\frac{\sigma_B}{\sigma}\right)\right) \tag{5.1}$$

$$d=\mathrm{Stop\_Gradient}\left(\mathrm{Clip}_{[-d_{max},\,d_{max}]}\left(\frac{\mu_B-\mu}{\sigma}\right)\right) \tag{6.1}$$


Update moving averages:

$$\mu:=\mu+\alpha(\mu_B-\mu) \tag{7.1}$$

$$\sigma:=\sigma+\alpha(\sigma_B-\sigma) \tag{8.1}$$

Inference:

$$y=\gamma\frac{x-\mu}{\sigma}+\beta \tag{9.1}$$

Whereas standard BN only accumulates the moving mean and variance during training and uses them only at inference time, BRN uses the moving statistics in both training and inference.
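As a minimal sketch of the training-time computation (the names `brn_forward`, `mu_mov`, and `sigma_mov` are illustrative; plain NumPy has no autograd, so the Stop_Gradient in (5.1)/(6.1) is implicit in the fact that `r` and `d` are just constants here):

```python
def brn_forward(x, gamma, beta, mu_mov, sigma_mov,
                alpha=0.01, r_max=3.0, d_max=5.0, eps=1e-5, training=True):
    """Batch Renorm forward, formulas (1.1)-(9.1); x has shape (m, N).
    mu_mov / sigma_mov are the moving statistics (arrays), updated in place."""
    if not training:
        return gamma * (x - mu_mov) / sigma_mov + beta       # (9.1): inference
    mu_b = x.mean(axis=0)                                    # (1.1)
    sigma_b = np.sqrt(eps + x.var(axis=0))                   # (2.1)
    r = np.clip(sigma_b / sigma_mov, 1.0 / r_max, r_max)     # (5.1): no gradient through r
    d = np.clip((mu_b - mu_mov) / sigma_mov, -d_max, d_max)  # (6.1): no gradient through d
    x_hat = (x - mu_b) / sigma_b * r + d                     # (3.1)
    y = gamma * x_hat + beta                                 # (4.1)
    mu_mov += alpha * (mu_b - mu_mov)                        # (7.1): update moving mean
    sigma_mov += alpha * (sigma_b - sigma_mov)               # (8.1): update moving std
    return y
```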

2. Backward

The backward derivation is analogous to BN's:
$$\frac{\partial L}{\partial \widehat{x_i}}=\frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial \widehat{x_i}}=\frac{\partial L}{\partial y_i}\,\gamma \tag{10.11}$$

$$\frac{\partial L}{\partial \sigma_B}\longrightarrow\sum_{i=1}^m \frac{\partial L}{\partial \widehat{x_i}}\,\frac{-r(x_i-\mu_B)}{\sigma_B^2} \tag{10.22}$$

$$\frac{\partial L}{\partial \mu_B}\longrightarrow\sum_{i=1}^m \frac{\partial L}{\partial \widehat{x_i}}\,\frac{-r}{\sigma_B} \tag{10.33}$$

$$\frac{\partial L}{\partial x_i}=\frac{\partial L}{\partial \widehat{x_i}}\,\frac{r}{\sigma_B}+\frac{\partial L}{\partial \sigma_B}\,\frac{x_i-\mu_B}{m\,\sigma_B}+\frac{\partial L}{\partial \mu_B}\,\frac{1}{m} \tag{10.44}$$

$$\frac{\partial L}{\partial \gamma}=\sum_{i=1}^m \frac{\partial L}{\partial y_i}\widehat{x_i} \tag{10.55}$$

$$\frac{\partial L}{\partial \beta}=\sum_{i=1}^m \frac{\partial L}{\partial y_i} \tag{10.66}$$
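These formulas translate directly into a NumPy sketch (an illustrative `brn_backward`, consuming the quantities produced by the `brn_forward` sketch above):

```python
def brn_backward(dy, x, x_hat, mu_b, sigma_b, r, gamma):
    """Batch Renorm backward, formulas (10.11)-(10.66); dy has shape (m, N)."""
    m = dy.shape[0]
    dx_hat = dy * gamma                                                  # (10.11)
    dsigma_b = np.sum(dx_hat * (-r) * (x - mu_b) / sigma_b**2, axis=0)   # (10.22)
    dmu_b = np.sum(dx_hat * (-r) / sigma_b, axis=0)                      # (10.33)
    dx = (dx_hat * r / sigma_b
          + dsigma_b * (x - mu_b) / (m * sigma_b)
          + dmu_b / m)                                                   # (10.44)
    dgamma = np.sum(dy * x_hat, axis=0)                                  # (10.55)
    dbeta = np.sum(dy, axis=0)                                           # (10.66)
    return dx, dgamma, dbeta
```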

IV. BN in Convolutional Networks

The derivations above are all based on an MLP. In a convolutional network, the $m$ activation units in BN generalize to $m$ feature-map values. Suppose the feature map after some convolution is a tensor of shape $[N,H,W,C]$, where $N$ is the batch size, $H$ and $W$ are the height and width, and $C$ is the number of channels. To apply BN in a conv net, set $m = N\times H\times W$: any feature-map pixel of a given channel $c$ within the $i$-th sample of the batch is treated as an $x_i$, and the BN formulas above apply unchanged. Consequently, each channel of an activation layer has its own pair of BN parameters $\gamma,\beta$.
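As a sketch (assuming NHWC layout; the helper name is again illustrative), the only change from the MLP case is the set of axes the statistics are reduced over:

```python
def bn_forward_conv(x, gamma, beta, eps=1e-5):
    """BN over an [N, H, W, C] tensor: per-channel statistics with m = N*H*W."""
    mu = x.mean(axis=(0, 1, 2))                  # shape (C,)
    var = x.var(axis=(0, 1, 2))                  # shape (C,)
    x_hat = (x - mu) / np.sqrt(var + eps)        # broadcasts across N, H, W
    return gamma * x_hat + beta                  # one (gamma, beta) pair per channel
```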
