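Source: http://blog.csdn.net/mounty_fsc/article/details/51588773
(Caffe, LeNet) Weight Update (Part 7)
In Solver::ApplyUpdate(), Caffe takes the partial derivatives of the loss with respect to the network weights, which the backward pass computed. It then applies the configured learning policy to update the weights, completing one training iteration.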
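1 Model Optimization
1.1 Loss Function
The loss $L(W)$ is the empirical loss plus a regularization term, as shown below. The sum runs over a mini-batch rather than the whole dataset, which is why it is an approximation. Here $X^{(i)}$ is an input sample, $f_W$ is the loss on a single sample, $N$ is the number of samples in a mini-batch, and $r(W)$ is the regularization term, weighted by $\lambda$.
$$L(W) \approx \frac{1}{N}\sum_{i}^{N} f_W\left(X^{(i)}\right) + \lambda\, r(W)$$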
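In Caffe, this splits into three stages:
- Forward pass: computes $f_W$
- Backward pass: computes $\nabla f_W$
- Weight update: computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$, etc., and uses it to update $W$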
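The mini-batch loss from 1.1 can be sketched as follows. This sketch assumes that the forward pass has already produced the per-sample losses $f_W(X^{(i)})$. The text above does not specify $r(W)$, so the sketch also assumes the L2 form $r(W) = \frac{1}{2}\|W\|^2$. `MiniBatchLoss` is a hypothetical helper for illustration, not part of Caffe.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not Caffe code): evaluates
//   L(W) ~= (1/N) * sum_i f_W(X^(i)) + lambda * r(W)
// assuming the L2 regularizer r(W) = 0.5 * sum_j w_j^2.
// per_sample_loss holds f_W(X^(i)) for the N samples of one mini-batch.
double MiniBatchLoss(const std::vector<double>& per_sample_loss,
                     const std::vector<double>& weights, double lambda) {
  assert(!per_sample_loss.empty());  // N must be at least 1

  // Empirical part: average loss over the mini-batch.
  double data_loss = 0.0;
  for (double f : per_sample_loss) data_loss += f;
  data_loss /= static_cast<double>(per_sample_loss.size());

  // Regularization part: r(W) = 0.5 * ||W||^2 (assumed L2 form).
  double r = 0.0;
  for (double w : weights) r += w * w;
  r *= 0.5;

  return data_loss + lambda * r;
}
```

For example, take per-sample losses {1, 2, 3}, weights {1, 2} and λ = 0.5. The empirical part is 2 and r(W) = 2.5, so L(W) = 2 + 0.5 · 2.5 = 3.25.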
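1.2 Stochastic Gradient Descent
In LeNet, the solver type is SGD (stochastic gradient descent).
SGD updates the weights with the following formulas:
$$W_{t+1} = W_t + V_{t+1}$$
$$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$$
The symbols are defined as follows:
- $W_{t+1}$ is the weight at iteration $t+1$.
- $V_{t+1}$ is the update at iteration $t+1$, also written $\Delta W_{t+1}$.
- $\mu$ is the momentum, the weight given to the previous update.
- $\alpha$ is the learning rate.
- $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.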
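Below is a minimal C++ sketch of these two update equations on a flat parameter vector. `MomentumSgdStep` is a hypothetical name used for illustration. The sketch assumes the backward pass supplies the gradient $\nabla L(W_t)$.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One SGD-with-momentum step, written exactly as the formulas above:
//   V_{t+1} = mu * V_t - alpha * grad L(W_t)
//   W_{t+1} = W_t + V_{t+1}
// w (weights), v (previous update) and grad are flattened vectors of equal length.
void MomentumSgdStep(std::vector<double>& w, std::vector<double>& v,
                     const std::vector<double>& grad, double mu,
                     double alpha) {
  assert(w.size() == v.size() && w.size() == grad.size());
  for (std::size_t j = 0; j < w.size(); ++j) {
    v[j] = mu * v[j] - alpha * grad[j];  // V_{t+1}
    w[j] += v[j];                        // W_{t+1} = W_t + V_{t+1}
  }
}
```

Consider f(w) = ½w², whose gradient is w, with μ = 0.9 and α = 0.1. Starting from w = 1, two steps move w to 0.9 and then 0.72. The second step (0.18) is twice the plain gradient step (0.09) because V carries over part of the first update.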
2 Code Analysis
2.1 ApplyUpdate
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // Get the learning rate for the current iteration
  Dtype rate = GetLearningRate();
  // Update every learnable parameter blob.
  // In LeNet only the four layers `conv1`, `conv2`, `ip1` and `ip2` have
  // parameters, and each has a weight blob and a bias blob,
  // so `learnable_params_` has size 8.
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // Normalize the gradient; a no-op when iter_size is 1, as in LeNet
    Normalize(param_id);
    // Regularize (add weight decay to the gradient)
    Regularize(param_id);
    // Compute the update value \delta w
    ComputeUpdateValue(param_id, rate);
  }
  // Apply the updates to the weights
  this->net_->Update();
}
```
Notes:
- LeNet's learning parameters are set in `lenet_solver.prototxt`:
```
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
```
- The code of the learning-rate function GetLearningRate() is not reproduced here. Its comments (and caffe.proto) list the learning-rate policies below. LeNet uses the `inv` policy, under which the learning rate changes at every iteration.
```cpp
// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
//    - fixed: always return base_lr.
//    - step: return base_lr * gamma ^ (floor(iter / step))
//    - exp: return base_lr * gamma ^ iter
//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
//    - multistep: similar to step but it allows non uniform steps defined by
//      stepvalue
//    - poly: the effective learning rate follows a polynomial decay, to be
//      zero by the max_iter. return base_lr * (1 - iter/max_iter) ^ (power)
//    - sigmoid: the effective learning rate follows a sigmoid decay
//      return base_lr * ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
```
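The following C++ sketch evaluates the formulas from the comment above to make the policies concrete. `LrParams` and `LearningRate` are hypothetical names, not Caffe's GetLearningRate(). The defaults for base_lr, gamma and power are LeNet's values from `lenet_solver.prototxt`. The stepsize and max_iter defaults are placeholders, and multistep is left out.

```cpp
#include <cassert>
#include <cmath>
#include <string>

// Hypothetical stand-in for the SolverParameter fields used by the policies.
// base_lr, gamma and power default to LeNet's lenet_solver.prototxt values;
// stepsize and max_iter are placeholders that LeNet's "inv" policy ignores.
struct LrParams {
  double base_lr = 0.01;
  double gamma = 0.0001;
  double power = 0.75;
  int stepsize = 1;
  int max_iter = 10000;
};

// Learning rate at iteration `iter`, following the formulas in the comment
// above ("multistep" is omitted to keep the sketch short).
double LearningRate(const std::string& policy, const LrParams& p, int iter) {
  const double it = static_cast<double>(iter);
  if (policy == "fixed") return p.base_lr;
  if (policy == "step")
    return p.base_lr * std::pow(p.gamma, std::floor(it / p.stepsize));
  if (policy == "exp") return p.base_lr * std::pow(p.gamma, it);
  if (policy == "inv")
    return p.base_lr * std::pow(1.0 + p.gamma * it, -p.power);
  if (policy == "poly")
    return p.base_lr * std::pow(1.0 - it / p.max_iter, p.power);
  if (policy == "sigmoid")
    return p.base_lr / (1.0 + std::exp(-p.gamma * (it - p.stepsize)));
  assert(false && "unknown lr_policy");
  return 0.0;
}
```

With LeNet's values, `inv` gives 0.01 at iteration 0. At iteration 10000 it gives 0.01 · 2^(-0.75) ≈ 0.00595, because 1 + 0.0001 · 10000 = 2.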
2.2 Regularize
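This function effectively computes
$$\frac{\partial \text{loss}}{\partial w_{ij}} \leftarrow \text{decay} \cdot w_{ij} + \frac{\partial \text{loss}}{\partial w_{ij}}$$
where decay is `local_decay` in the code below, which is `weight_decay` multiplied by the blob's own decay multiplier.
The code is as follows: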
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();
  // local_decay = 0.0005 in lenet
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  ...
      if (regularization_type == "L2") {
        // axpy means ax_plus_y. i.e., y = a*x + y
        // Here: diff = local_decay * data + diff
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      }
  ...
}
```
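Adding decay · w_ij to the gradient is exactly the derivative of (decay/2) · ‖W‖². This branch therefore corresponds to choosing r(W) = ½‖W‖² in Section 1.1, with λ = local_decay. The sketch below mimics this branch on plain vectors. `Axpy` and `RegularizeL2` are hypothetical helpers, not Caffe's `caffe_axpy`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal stand-in for caffe_axpy: y[i] = a * x[i] + y[i].
void Axpy(double a, const std::vector<double>& x, std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] += a * x[i];
}

// Sketch of the "L2" branch of Regularize() for one parameter blob.
// data = weights (cpu_data), diff = gradient (mutable_cpu_diff).
// decay_mult plays the role of net_params_weight_decay[param_id].
void RegularizeL2(const std::vector<double>& data, std::vector<double>& diff,
                  double weight_decay, double decay_mult) {
  const double local_decay = weight_decay * decay_mult;
  Axpy(local_decay, data, diff);  // diff <- local_decay * data + diff
}
```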
2.3 ComputeUpdateValue
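This function effectively computes
$$v_{ij} \leftarrow \text{lr\_rate} \cdot \frac{\partial \text{loss}}{\partial w_{ij}} + \text{momentum} \cdot v_{ij}$$
$$\frac{\partial \text{loss}}{\partial w_{ij}} \leftarrow v_{ij}$$
Here $v_{ij}$ is stored in `history_`, and lr_rate is `local_rate` in the code below. Note that $v$ is the negative of $V$ in Section 1.2 ($v = -V$). Caffe accumulates the step with a positive sign and applies the minus sign later, in net_->Update().
The code is as follows: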
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  // momentum = 0.9 in lenet
  Dtype momentum = this->param_.momentum();
  // local_rate = lr_mult * global_rate
  // lr_mult is the learning-rate multiplier of this parameter blob,
  // set in lenet_train_test.prototxt
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  ...
    // axpby means ax_plus_by. i.e., y = ax + by
    // Compute the new update value \delta w and store it in the history
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    // Copy \delta w from the history into the parameter blob's diff
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
  ...
}
```
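The following C++ sketch performs the same two steps on plain vectors. `Axpby` and `ComputeUpdateValueSketch` are hypothetical names, and `history` plays the role of `history_[param_id]`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal stand-in for caffe_cpu_axpby: y[i] = a * x[i] + b * y[i].
void Axpby(double a, const std::vector<double>& x, double b,
           std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = a * x[i] + b * y[i];
}

// Sketch of ComputeUpdateValue() for one parameter blob:
//   history <- local_rate * diff + momentum * history
//   diff    <- history
void ComputeUpdateValueSketch(std::vector<double>& diff,
                              std::vector<double>& history, double rate,
                              double lr_mult, double momentum) {
  const double local_rate = rate * lr_mult;    // per-blob learning rate
  Axpby(local_rate, diff, momentum, history);  // like caffe_cpu_axpby
  diff = history;                              // like caffe_copy
}
```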
2.4 net_->Update
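net_->Update() calls Update() on every learnable parameter blob, which effectively computes:
$$w_{ij} \leftarrow w_{ij} + (-1) \cdot \frac{\partial \text{loss}}{\partial w_{ij}}$$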
```cpp
// Blob<Dtype>::Update(), CPU case: data_ = -1 * diff_ + data_
caffe_axpy<Dtype>(count_, Dtype(-1),
    static_cast<const Dtype*>(diff_->cpu_data()),
    static_cast<Dtype*>(data_->mutable_cpu_data()));
```
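Chaining ComputeUpdateValue() with this line shows why Caffe matches Section 1.2 despite the sign difference. Caffe accumulates v = −V in `history_` and subtracts it at this point. The following C++ sketch compares both forms on plain vectors. `CaffeStyleStep` and `TextbookStep` are hypothetical helpers, not Caffe functions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Caffe-style step: ComputeUpdateValue() followed by Blob::Update().
//   history <- alpha * grad + mu * history
//   data    <- data + (-1) * history   (caffe_axpy with a = -1)
void CaffeStyleStep(std::vector<double>& data, std::vector<double>& history,
                    const std::vector<double>& grad, double mu, double alpha) {
  assert(data.size() == history.size() && data.size() == grad.size());
  for (std::size_t j = 0; j < data.size(); ++j) {
    history[j] = alpha * grad[j] + mu * history[j];
    data[j] += -1.0 * history[j];
  }
}

// Section 1.2 form: V <- mu * V - alpha * grad; W <- W + V.
void TextbookStep(std::vector<double>& w, std::vector<double>& v,
                  const std::vector<double>& grad, double mu, double alpha) {
  assert(w.size() == v.size() && w.size() == grad.size());
  for (std::size_t j = 0; j < w.size(); ++j) {
    v[j] = mu * v[j] - alpha * grad[j];
    w[j] += v[j];
  }
}
```

Running both forms on f(w) = ½w² from the same starting point keeps the weights identical, and `history` stays equal to −V after every step.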