The authors give the batch normalization algorithm as follows. The ε in the algorithm is a small constant added to the mini-batch variance for numerical stability.
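To make the transform concrete, here is a minimal NumPy sketch of the forward pass over one mini-batch; the function name, the `(m, d)` input layout, and the cache returned for backprop are my own choices, not from the paper:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Batch-normalizing transform for one mini-batch.

    x: (m, d) mini-batch; gamma, beta: (d,) learned scale and shift.
    Returns the output y and a cache of intermediates for backprop.
    """
    mu = x.mean(axis=0)                    # mini-batch mean, mu_B
    var = x.var(axis=0)                    # mini-batch variance, sigma_B^2 (1/m normalization)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize; eps keeps the denominator away from zero
    y = gamma * x_hat + beta               # scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)
    return y, cache
```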
The BN transform is differentiable, so as the network trains, each layer can keep learning from the distribution of its inputs.
(Here the variance is taken as the unbiased estimate.) Inference therefore applies BN with fixed population statistics:

$$y = \gamma\,\frac{x - E[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} + \beta$$
The authors' complete algorithm for training a batch-normalized network:
![The complete algorithm for training a batch-normalized network](http://www.2cto.com/uploadfile/Collfiles/20170110/20170110110554169.png)
Backpropagation through the BN transform: since

$$y^{(k)} = \gamma^{(k)}\,\widehat{x}^{(k)} + \beta^{(k)},$$

the chain rule gives

$$\frac{\partial \ell}{\partial \widehat{x}_i} = \frac{\partial \ell}{\partial y_i}\,\gamma.$$

Since

$$\widehat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}},$$

we get

$$\frac{\partial \ell}{\partial \sigma_B^2} = \sum_{i=1}^m \frac{\partial \ell}{\partial \widehat{x}_i}\,(x_i - \mu_B)\,\frac{-1}{2}\,(\sigma_B^2 + \varepsilon)^{-3/2}$$

and

$$\frac{\partial \ell}{\partial \mu_B} = \sum_{i=1}^m \frac{\partial \ell}{\partial \widehat{x}_i}\,\frac{-1}{\sqrt{\sigma_B^2 + \varepsilon}}.$$

Since

$$\mu_B = \frac{1}{m}\sum_{i=1}^m x_i \qquad \text{and} \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2,$$

the gradient with respect to the inputs is

$$\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \widehat{x}_i}\,\frac{1}{\sqrt{\sigma_B^2 + \varepsilon}} + \frac{\partial \ell}{\partial \sigma_B^2}\,\frac{2(x_i - \mu_B)}{m} + \frac{\partial \ell}{\partial \mu_B}\,\frac{1}{m},$$

and the gradients for the learned parameters are

$$\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^m \frac{\partial \ell}{\partial y_i}\,\widehat{x}_i, \qquad \frac{\partial \ell}{\partial \beta} = \sum_{i=1}^m \frac{\partial \ell}{\partial y_i}.$$
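Translating these equations directly into code, a sketch of the corresponding backward pass might look like this; it consumes the cache produced by the `batchnorm_forward` sketch above, and again the names and layout are mine:

```python
import numpy as np

def batchnorm_backward(dy, cache):
    """Gradients of the BN transform, one line per equation above.

    dy: (m, d) upstream gradient dl/dy. Returns (dx, dgamma, dbeta).
    """
    x, x_hat, mu, var, gamma, eps = cache
    m = x.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)

    dx_hat = dy * gamma                                           # dl/dx_hat
    dvar = np.sum(dx_hat * (x - mu), axis=0) * -0.5 * inv_std**3  # dl/dsigma_B^2
    dmu = np.sum(dx_hat, axis=0) * -inv_std                       # dl/dmu_B
    dx = dx_hat * inv_std + dvar * 2.0 * (x - mu) / m + dmu / m   # dl/dx_i
    dgamma = np.sum(dy * x_hat, axis=0)                           # dl/dgamma
    dbeta = np.sum(dy, axis=0)                                    # dl/dbeta
    return dx, dgamma, dbeta
```

A quick finite-difference check of `dx` against a numerical gradient is a good way to validate this kind of implementation.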
Training and inference with BN networks
Following the BN method, the input x is transformed into BN(x), and the network can then be trained with stochastic gradient descent; since normalization is computed per mini-batch, training stays very efficient. For inference, however, we want the output to depend only on the input, and with a single input instance there are no other mini-batch instances from which to compute a mean and variance. We can therefore **replace** the mean and variance computed from the m training instances of a mini-batch with statistics obtained from all training instances: since we normalize every mini-batch B during training anyway, we can record each mini-batch's moments and derive the global statistics
$$E[x] \leftarrow E_B[\mu_B], \qquad \mathrm{Var}[x] \leftarrow \frac{m}{m-1}\,E_B[\sigma_B^2],$$

where the factor m/(m−1) makes Var[x] the unbiased estimate. With these fixed statistics, inference-time BN reduces to a single affine transform:

$$y = \gamma\,\frac{x - E[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} + \beta = \frac{\gamma}{\sqrt{\mathrm{Var}[x] + \varepsilon}}\,x + \left(\beta - \frac{\gamma\,E[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}}\right).$$
![Inference with a batch-normalized network](http://www.2cto.com/uploadfile/Collfiles/20170110/20170110110602222.png)
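As a sketch of this folding step, assuming the per-mini-batch moments were recorded during training (the function name, and storing every mini-batch's moments rather than a running average, are simplifications of mine):

```python
import numpy as np

def fold_bn_for_inference(batch_means, batch_vars, gamma, beta, m, eps=1e-5):
    """Fold population statistics and (gamma, beta) into one affine map.

    batch_means, batch_vars: lists of mu_B and sigma_B^2 from training.
    m: mini-batch size, used for the unbiased correction m / (m - 1).
    """
    e_x = np.mean(batch_means, axis=0)                 # E[x] <- E_B[mu_B]
    var_x = m / (m - 1) * np.mean(batch_vars, axis=0)  # Var[x] <- m/(m-1) E_B[sigma_B^2]
    scale = gamma / np.sqrt(var_x + eps)               # gamma / sqrt(Var[x] + eps)
    shift = beta - scale * e_x                         # beta - gamma E[x] / sqrt(Var[x] + eps)
    return scale, shift                                # at inference: y = scale * x + shift
```

In practice, frameworks usually maintain running averages of μ_B and σ_B² during training instead of storing every mini-batch's moments; the folded `scale` and `shift` are what make inference-time BN a single cheap affine transform.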
Experiments
The experiments given at the end show that training with BN achieves high accuracy and is very stable.
![Experimental results](http://www.2cto.com/uploadfile/Collfiles/20170110/20170110110602224.png)