Original article: http://blog.csdn.net/shuzfan/article/details/51338178
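"Xavier" initialization is a very effective way to initialize neural networks. It comes from the 2010 paper "Understanding the difficulty of training deep feedforward neural networks" by Xavier Glorot and Yoshua Bengio (hence the name). Unfortunately, only in the last couple of years has the method gradually gained wider adoption and recognition.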
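For information to flow well through the network, the variance of each layer's output should be kept as equal as possible from layer to layer.
With this goal in mind, let us derive the condition that each layer's weights should satisfy.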
The paper starts by assuming a (locally) linear activation function whose derivative at 0 is 1, i.e.
$$f'(0) = 1.$$
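Let us first analyze a single convolutional layer. One output unit computes
$$y = w_1 x_1 + w_2 x_2 + \cdots + w_{n_i} x_{n_i} + b,$$
where $n_i$ is the number of inputs to that unit (for a convolution: input channels × kernel height × kernel width).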
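From basic probability, for independent $w_k$ and $x_k$ the variance of each product term is:
$$\mathrm{Var}(w_k x_k) = E[w_k]^2\,\mathrm{Var}(x_k) + E[x_k]^2\,\mathrm{Var}(w_k) + \mathrm{Var}(w_k)\,\mathrm{Var}(x_k).$$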
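In particular, if we assume that both the inputs and the weights have zero mean (now that Batch Normalization is common, this is fairly easy to satisfy), the expression simplifies to:
$$\mathrm{Var}(w_k x_k) = \mathrm{Var}(w_k)\,\mathrm{Var}(x_k).$$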
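If we further assume that the inputs $x_k$ are i.i.d., the weights $w_k$ are i.i.d., and the two are independent of each other (the bias $b$ is a constant, usually 0, and adds no variance), then:
$$\mathrm{Var}(y) = \sum_{k=1}^{n_i} \mathrm{Var}(w_k x_k) = n_i\,\mathrm{Var}(w)\,\mathrm{Var}(x).$$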
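Therefore, for the output variance to equal the input variance, we need:
$$n_i\,\mathrm{Var}(w) = 1 \quad\Longleftrightarrow\quad \mathrm{Var}(w) = \frac{1}{n_i}.$$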
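A quick numerical check of the condition $\mathrm{Var}(w) = 1/n_i$: the minimal C++ sketch below (not part of the original post) samples zero-mean inputs and weights, computes $y = \sum_k w_k x_k$ with $b = 0$, and estimates $\mathrm{Var}(y)$ by Monte Carlo. Using Gaussian distributions is an assumption made only for illustration, and the helper name `output_variance` is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical helper (illustration only): Monte Carlo estimate of Var(y)
// for y = sum_k w_k * x_k with b = 0, where x_k ~ N(0, var_x) and
// w_k ~ N(0, var_w) are all drawn independently.
double output_variance(int n_in, double var_x, double var_w,
                       int trials, unsigned seed) {
  assert(n_in > 0 && trials > 1);
  std::mt19937 gen(seed);
  std::normal_distribution<double> x_dist(0.0, std::sqrt(var_x));
  std::normal_distribution<double> w_dist(0.0, std::sqrt(var_w));
  double sum = 0.0, sum_sq = 0.0;
  for (int t = 0; t < trials; ++t) {
    double y = 0.0;
    for (int k = 0; k < n_in; ++k) y += w_dist(gen) * x_dist(gen);
    sum += y;
    sum_sq += y * y;
  }
  const double mean = sum / trials;
  return sum_sq / trials - mean * mean;  // population variance of y
}
```

For example, `output_variance(100, 1.0, 1.0 / 100, 20000, 42)` should come out close to 1, and doubling `var_w` roughly doubles the result, as $\mathrm{Var}(y) = n_i\,\mathrm{Var}(w)\,\mathrm{Var}(x)$ predicts.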
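For a multi-layer network, the variance of the activations $z^i$ of layer $i$ can be written as a product over the preceding layers:
$$\mathrm{Var}(z^i) = \mathrm{Var}(x) \prod_{i'=0}^{i-1} n_{i'}\,\mathrm{Var}(W^{i'}),$$
where $n_{i'}$ is the size of layer $i'$ and $W^{i'}$ are the weights mapping layer $i'$ to layer $i'+1$.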
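The gradients computed during back-propagation have an analogous form. With $s^i$ the pre-activation input of layer $i$ and $d$ the last layer:
$$\mathrm{Var}\!\left(\frac{\partial\,\mathrm{Cost}}{\partial s^i}\right) = \mathrm{Var}\!\left(\frac{\partial\,\mathrm{Cost}}{\partial s^d}\right) \prod_{i'=i}^{d} n_{i'+1}\,\mathrm{Var}(W^{i'}).$$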
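Putting this together, keeping the variance the same in every layer during both forward and backward propagation requires:
$$\forall i:\ n_i\,\mathrm{Var}(W^i) = 1, \qquad \forall i:\ n_{i+1}\,\mathrm{Var}(W^i) = 1.$$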
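To see why the per-layer factor matters, the sketch below pushes a unit-variance signal forward through a deep linear network of equal width $n$ and then pushes a unit-variance gradient back through the transposed weights. With equal widths, the conditions $n_i\,\mathrm{Var}(W^i) = 1$ and $n_{i+1}\,\mathrm{Var}(W^i) = 1$ coincide. Gaussian weights and the helper names `mean_square` and `deep_linear_variances` are assumptions for illustration, not from the original post.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Mean of squared entries; for zero-mean data this estimates the variance.
double mean_square(const std::vector<double>& v) {
  assert(!v.empty());
  double s = 0.0;
  for (double e : v) s += e * e;
  return s / static_cast<double>(v.size());
}

// Hypothetical helper: `depth` linear layers of width n with weights
// ~ N(0, var_w). Returns {forward variance at the output,
// backward (gradient) variance at the input}.
std::pair<double, double> deep_linear_variances(int n, int depth, double var_w,
                                                unsigned seed) {
  std::mt19937 gen(seed);
  std::normal_distribution<double> unit(0.0, 1.0);
  std::normal_distribution<double> w_dist(0.0, std::sqrt(var_w));
  std::vector<Matrix> W(depth, Matrix(n, std::vector<double>(n)));
  for (auto& M : W)
    for (auto& row : M)
      for (double& w : row) w = w_dist(gen);

  std::vector<double> z(n), g(n);
  for (double& e : z) e = unit(gen);  // unit-variance input signal
  for (double& e : g) e = unit(gen);  // unit-variance output gradient

  for (int l = 0; l < depth; ++l) {  // forward: z <- W^l z
    std::vector<double> next(n, 0.0);
    for (int i = 0; i < n; ++i)
      for (int j = 0; j < n; ++j) next[i] += W[l][i][j] * z[j];
    z = next;
  }
  for (int l = depth - 1; l >= 0; --l) {  // backward: g <- (W^l)^T g
    std::vector<double> prev(n, 0.0);
    for (int i = 0; i < n; ++i)
      for (int j = 0; j < n; ++j) prev[j] += W[l][i][j] * g[i];
    g = prev;
  }
  return {mean_square(z), mean_square(g)};
}
```

With `var_w = 1.0 / n` both returned variances stay of order 1 after 10 layers, while `var_w = 2.0 / n` makes each layer roughly double the variance, so after 10 layers both are on the order of $2^{10}$.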
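In practice, however, the numbers of inputs and outputs of a layer are usually different, so both conditions cannot hold at the same time. As a compromise, the weight variance is chosen to satisfy:
$$\mathrm{Var}(W^i) = \frac{2}{n_i + n_{i+1}}.$$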
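The trade-off can be made concrete by evaluating the two products from the derivation, $\prod_i n_i\,\mathrm{Var}(W^i)$ (forward) and $\prod_i n_{i+1}\,\mathrm{Var}(W^i)$ (backward), for a network whose widths shrink. The C++ sketch below does this under three choices of $\mathrm{Var}(W^i)$; the widths `{784, 512, 256, 10}` and the names `Rule`, `weight_variance`, and `variance_gains` are hypothetical, chosen only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class Rule { kFanIn, kFanOut, kAverage };

// Var(W^i) for a layer with n_in inputs and n_out outputs.
double weight_variance(double n_in, double n_out, Rule rule) {
  switch (rule) {
    case Rule::kFanIn:   return 1.0 / n_in;             // forward condition
    case Rule::kFanOut:  return 1.0 / n_out;            // backward condition
    case Rule::kAverage: return 2.0 / (n_in + n_out);   // Xavier compromise
  }
  return 0.0;  // unreachable
}

struct Gains { double forward; double backward; };

// For layer widths n[0..L]:
//   forward  = prod_i n_i     * Var(W^i)
//   backward = prod_i n_{i+1} * Var(W^i)
Gains variance_gains(const std::vector<double>& n, Rule rule) {
  assert(n.size() >= 2);
  Gains g{1.0, 1.0};
  for (std::size_t i = 0; i + 1 < n.size(); ++i) {
    const double v = weight_variance(n[i], n[i + 1], rule);
    g.forward *= n[i] * v;
    g.backward *= n[i + 1] * v;
  }
  return g;
}
```

For these widths, the fan-in rule keeps the forward product at 1 but shrinks the backward product to $10/784$, the fan-out rule does the opposite (forward product $78.4$), and the averaged rule lands in between on both sides.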
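As anyone who has studied probability knows, the variance of the uniform distribution on $[a, b]$ is:
$$\mathrm{Var} = \frac{(b-a)^2}{12}.$$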
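For a symmetric interval $[-a, a]$ this gives $a^2/3$; setting it equal to $2/(n_i + n_{i+1})$ yields $a = \sqrt{6}/\sqrt{n_i + n_{i+1}}$. Xavier initialization is therefore implemented as the following uniform distribution:
$$W \sim U\!\left[-\frac{\sqrt{6}}{\sqrt{n_i + n_{i+1}}},\ \frac{\sqrt{6}}{\sqrt{n_i + n_{i+1}}}\right].$$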
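A minimal C++ sketch of this formula, written with the standard `<random>` header rather than any framework: `xavier_bound` computes the half-width $a$, `xavier_uniform` draws weights from $U[-a, a]$, and `sample_variance` checks empirically that the variance is close to $2/(n_{in} + n_{out})$. All three names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Half-width a of U[-a, a] whose variance a^2 / 3 equals 2 / (n_in + n_out).
double xavier_bound(int n_in, int n_out) {
  assert(n_in > 0 && n_out > 0);
  return std::sqrt(6.0 / (n_in + n_out));
}

// Draws `count` weights from U[-a, a] with a = xavier_bound(n_in, n_out).
std::vector<double> xavier_uniform(int n_in, int n_out, int count,
                                   unsigned seed) {
  const double a = xavier_bound(n_in, n_out);
  std::mt19937 gen(seed);
  std::uniform_real_distribution<double> dist(-a, a);
  std::vector<double> w(count);
  for (double& x : w) x = dist(gen);
  return w;
}

// Empirical (population) variance of a sample.
double sample_variance(const std::vector<double>& v) {
  assert(!v.empty());
  double mean = 0.0;
  for (double x : v) mean += x;
  mean /= static_cast<double>(v.size());
  double var = 0.0;
  for (double x : v) var += (x - mean) * (x - mean);
  return var / static_cast<double>(v.size());
}
```

For a layer with 256 inputs and 128 outputs, `xavier_bound(256, 128)` is $\sqrt{6/384} = 0.125$, and the sample variance of a large draw should be close to $2/384 \approx 0.0052$.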
Next, let's look at how Caffe actually implements this. The code is in include/caffe/filler.hpp:
```cpp
template <typename Dtype>
class XavierFiller : public Filler<Dtype> {
 public:
  explicit XavierFiller(const FillerParameter& param)
      : Filler<Dtype>(param) {}
  virtual void Fill(Blob<Dtype>* blob) {
    CHECK(blob->count());
    // For a conv weight blob of shape (num_output, channels, kernel_h, kernel_w):
    //   fan_in  = channels   * kernel_h * kernel_w
    //   fan_out = num_output * kernel_h * kernel_w
    int fan_in = blob->count() / blob->num();
    int fan_out = blob->count() / blob->channels();
    Dtype n = fan_in;  // default to fan_in
    if (this->filler_param_.variance_norm() ==
        FillerParameter_VarianceNorm_AVERAGE) {
      n = (fan_in + fan_out) / Dtype(2);
    } else if (this->filler_param_.variance_norm() ==
        FillerParameter_VarianceNorm_FAN_OUT) {
      n = fan_out;
    }
    // U[-scale, scale] has variance scale^2 / 3 = 1 / n.
    Dtype scale = sqrt(Dtype(3) / n);
    caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
        blob->mutable_cpu_data());
    CHECK_EQ(this->filler_param_.sparse(), -1)
         << "Sparsity not supported by this Filler.";
  }
};
```
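For readers without a Caffe build, here is a standalone C++ sketch of the same computation. It is not Caffe code: `VarianceNorm`, `BlobShape`, `xavier_scale`, and `xavier_fill` are hypothetical names, `std::mt19937` with `std::uniform_real_distribution` stands in for Caffe's `caffe_rng_uniform`, and the blob is assumed to be a convolution weight in Caffe's (num_output, channels, kernel_h, kernel_w) layout.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Stand-ins for Caffe's variance_norm choices (hypothetical names).
enum class VarianceNorm { kFanIn, kFanOut, kAverage };

// Assumed layout: (num_output, channels, kernel_h, kernel_w).
struct BlobShape { int num, channels, height, width; };

// Half-width `scale` chosen the same way as in XavierFiller::Fill.
double xavier_scale(const BlobShape& s, VarianceNorm norm) {
  const int count = s.num * s.channels * s.height * s.width;
  assert(count > 0);
  const double fan_in = count / s.num;        // channels * kh * kw
  const double fan_out = count / s.channels;  // num_output * kh * kw
  double n = fan_in;                          // default, as in Caffe
  if (norm == VarianceNorm::kAverage) {
    n = (fan_in + fan_out) / 2.0;
  } else if (norm == VarianceNorm::kFanOut) {
    n = fan_out;
  }
  return std::sqrt(3.0 / n);  // U[-scale, scale] has variance 1 / n
}

// Fills a flat buffer with `count` weights drawn from U[-scale, scale].
std::vector<float> xavier_fill(const BlobShape& s, VarianceNorm norm,
                               unsigned seed) {
  const float scale = static_cast<float>(xavier_scale(s, norm));
  std::mt19937 gen(seed);
  std::uniform_real_distribution<float> dist(-scale, scale);
  std::vector<float> data(static_cast<std::size_t>(s.num) * s.channels *
                          s.height * s.width);
  for (float& w : data) w = dist(gen);
  return data;
}
```

Note that the `kAverage` branch gives $\sqrt{3 / ((n_{in} + n_{out})/2)} = \sqrt{6/(n_{in} + n_{out})}$, the same bound as the paper.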
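As the code shows, Caffe's Xavier filler offers three options (writing $n_{in}$ for fan_in and $n_{out}$ for fan_out):
(1) Default: the variance only accounts for the number of inputs, $\mathrm{Var}(W) = 1/n_{in}$, i.e. $W \sim U[-\sqrt{3/n_{in}},\ \sqrt{3/n_{in}}]$.
(2) FillerParameter_VarianceNorm_FAN_OUT: the variance only accounts for the number of outputs, $\mathrm{Var}(W) = 1/n_{out}$, i.e. $W \sim U[-\sqrt{3/n_{out}},\ \sqrt{3/n_{out}}]$.
(3) FillerParameter_VarianceNorm_AVERAGE: the variance accounts for both, $\mathrm{Var}(W) = 2/(n_{in} + n_{out})$. Since $\sqrt{3/((n_{in}+n_{out})/2)} = \sqrt{6/(n_{in}+n_{out})}$, this is exactly the paper's formula.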
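To make the three options concrete, take a hypothetical first convolution layer with 3 input channels, 64 outputs, and a 3×3 kernel (the shape is chosen only for illustration), so the weight blob has shape (64, 3, 3, 3):

$$
\begin{aligned}
\text{count} &= 64 \times 3 \times 3 \times 3 = 1728,\\
n_{in} &= 1728 / 64 = 27, \qquad n_{out} = 1728 / 3 = 576,\\
\text{(1) default:}\quad \text{scale} &= \sqrt{3/27} = 1/3 \approx 0.333,\\
\text{(2) FAN\_OUT:}\quad \text{scale} &= \sqrt{3/576} \approx 0.0722,\\
\text{(3) AVERAGE:}\quad \text{scale} &= \sqrt{3/301.5} = \sqrt{6/603} \approx 0.0998.
\end{aligned}
$$

Because this layer has few inputs and many outputs, the default choice gives a range more than three times wider than the averaged one.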
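As for why the default considers only the inputs, my personal guess is that forward propagation of information is somewhat more important.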