最近做人脸项目,打算用Batch Normalization 优化网络模型,看到大神博客写的特别好,先转载一下。
本文转载于:http://blog.csdn.net/shuzfan/article/details/50723877
目录
本次所讲的内容为Batch Normalization,简称BN,来源于《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》,是一篇很好的paper。
1-Motivation
作者认为:网络训练过程中参数不断改变导致后续每一层输入的分布也发生变化,而学习的过程又要使每一层适应输入的分布,因此我们不得不降低学习率、小心地初始化。作者将分布发生变化称之为 internal covariate shift。
大家应该都知道,我们一般在训练网络的时会将输入减去均值,还有些人甚至会对输入做白化等操作,目的是为了加快训练。为什么减均值、白化可以加快训练呢,这里做一个简单地说明:
首先,图像数据是高度相关的,假设其分布如下图a所示(简化为2维)。由于初始化的时候,我们的参数一般都是0均值的,因此开始的拟合y=Wx+b,基本过原点附近,如图b红色虚线。因此,网络需要经过多次学习才能逐步达到如紫色实线的拟合,即收敛的比较慢。如果我们对输入数据先作减均值操作,如图c,显然可以加快学习。更进一步的,我们对数据再进行去相关操作,使得数据更加容易区分,这样又会加快训练,如图d。
白化的方式有好几种,常用的有PCA白化:即对数据进行PCA操作之后,在进行方差归一化。这样数据基本满足0均值、单位方差、弱相关性。作者首先考虑,对每一层数据都使用白化操作,但分析认为这是不可取的。因为白化需要计算协方差矩阵、求逆等操作,计算量很大,此外,反向传播时,白化操作不一定可导。于是,作者采用下面的Normalization方法。
2-Normalization via Mini-Batch Statistics
数据归一化方法很简单,就是要让数据具有0均值和单位方差,如下式:
但是作者又说如果简单的这么干,会降低层的表达能力。比如下图,在使用sigmoid激活函数的时候,如果把数据限制到0均值单位方差,那么相当于只使用了激活函数中近似线性的部分,这显然会降低模型表达能力。
为此,作者又为BN增加了2个参数,用来保持模型的表达能力。
于是最后的输出为:
上述公式中用到了均值E和方差Var,需要注意的是理想情况下E和Var应该是针对整个数据集的,但显然这是不现实的。因此,作者做了简化,用一个Batch的均值和方差作为对整个数据集均值和方差的估计。
整个BN的算法如下:
求导的过程也非常简单,有兴趣地可以自己再推导一遍或者直接参见原文。
测试
实际测试网络的时候,我们依然会应用下面的式子:
特别注意: 这里的均值和方差已经不是针对某一个Batch了,而是针对整个数据集而言。因此,在训练过程中除了正常的前向传播和反向求导之外,我们还要记录每一个Batch的均值和方差,以便训练完成之后按照下式计算整体的均值和方差:
BN before or after Activation
作者在文章中说应该把BN放在激活函数之前,这是因为Wx+b具有更加一致和非稀疏的分布。但是也有人做实验表明放在激活函数后面效果更好。这是实验链接,里面有很多有意思的对比实验:https://github.com/ducha-aiki/caffenet-benchmark
3-Experiments
作者在文章中也做了很多实验对比,我这里就简单说明2个。
下图a说明,BN可以加速训练。图b和c则分别展示了训练过程中输入数据分布的变化情况。
下表是一个实验结果的对比,需要注意的是在使用BN的过程中,作者发现Sigmoid激活函数比Relu效果要好。
BatchNormalization是神经网络中常用的参数初始化的方法。其算法流程图如下:
我们可以把这个流程图以门电路的形式展开,方便进行前向传播和后向传播:
那么前向传播非常简单,直接给出代码:
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">batchnorm_forward</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(x, gamma, beta, eps)</span>:</span>
N, D = x.shape
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#为了后向传播求导方便,这里都是分步进行的</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step1: 计算均值</span>
mu = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span>/N * np.sum(x, axis = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step2: 减均值</span>
xmu = x - mu
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step3: 计算方差</span>
sq = xmu ** <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>
var = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span>/N * np.sum(sq, axis = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step4: 计算x^的分母项</span>
sqrtvar = np.sqrt(var + eps)
ivar = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span>/sqrtvar
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step5: normalization->x^</span>
xhat = xmu * ivar
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step6: scale and shift</span>
gammax = gamma * xhat
out = gammax + beta
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#存储中间变量</span>
cache = (xhat,gamma,xmu,ivar,sqrtvar,var,eps)
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> out, cache</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li></ul>
反向传播则是求导的过程,这里特别要小心,由于门电路中有多个支路,求导时要进行加和。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">batchnorm_backward</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(dout, cache)</span>:</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#解压中间变量</span>
xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
N,D = dout.shape
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step6</span>
dbeta = np.sum(dout, axis=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
dgammax = dout
dgamma = np.sum(dgammax*xhat, axis=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
dxhat = dgammax * gamma
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step5</span>
divar = np.sum(dxhat*xmu, axis=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
dxmu1 = dxhat * ivar <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#注意这是xmu的一个支路</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step4</span>
dsqrtvar = -<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span> /(sqrtvar**<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>) * divar
dvar = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.5</span> * <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span> /np.sqrt(var+eps) * dsqrtvar
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step3</span>
dsq = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span> /N * np.ones((N,D)) * dvar
dxmu2 = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span> * xmu * dsq <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#注意这是xmu的第二个支路</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step2</span>
dx1 = (dxmu1 + dxmu2) 注意这是x的一个支路
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step1</span>
dmu = -<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span> * np.sum(dxmu1+dxmu2, axis=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)
dx2 = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.</span> /N * np.ones((N,D)) * dmu 注意这是x的第二个支路
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">#step0 done!</span>
dx = dx1 + dx2
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> dx, dgamma, dbeta</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li></ul>
要注意的就是求导时遇到多个支路的情况要进行累加。表达式复杂的话还是分步进行比较不容易出错。