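How Weights Are Updated in Caffe

Source: http://blog.csdn.net/mounty_fsc/article/details/51588773


(Caffe, LeNet) Weight Update (Part 7)

Solver::ApplyUpdate() (implemented for LeNet by SGDSolver::ApplyUpdate()) takes the partial derivatives of the loss with respect to the network weights, computed during the backward pass, and applies the configured learning policy to update the weights. This completes one training iteration.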

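1 Model Optimization

1.1 Loss Function

The loss function $L(W)$ is the empirical loss plus a regularization term, as shown below, where $X^{(i)}$ is an input sample, $f_W$ is the loss on a single sample, $N$ is the number of samples in the mini-batch, and $r(W)$ is the regularization term, weighted by $\lambda$.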

$$L(W) \approx \frac{1}{N} \sum_{i}^{N} f_W\left(X^{(i)}\right) + \lambda r(W)$$
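
As a minimal sketch of the formula above, the following C++ evaluates the mini-batch loss from already-computed per-sample losses. The function name `RegularizedLoss` and the choice of an L2 regularizer $r(W) = \frac{1}{2}\sum_j w_j^2$ are assumptions made for illustration; this is not Caffe code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of Caffe): evaluates
//   L(W) = (1/N) * sum_i f_W(X^(i)) + lambda * r(W),
// assuming an L2 regularizer r(W) = 0.5 * sum_j w_j^2.
double RegularizedLoss(const std::vector<double>& per_sample_loss,
                       const std::vector<double>& weights,
                       double lambda) {
  assert(!per_sample_loss.empty());

  // Empirical loss: the mean of the per-sample losses over the mini-batch.
  double data_term = 0.0;
  for (double f : per_sample_loss) data_term += f;
  data_term /= static_cast<double>(per_sample_loss.size());

  // Regularization term r(W).
  double reg_term = 0.0;
  for (double w : weights) reg_term += 0.5 * w * w;

  return data_term + lambda * reg_term;
}
```

With this choice of $r(W)$, the gradient of the regularization term is $\lambda w_j$, which is the form of the weight-decay term that the L2 branch of Regularize in section 2.2 adds to each gradient.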

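In Caffe, each training iteration can be divided into three stages:

  1. Forward pass: computes $f_W$.
  2. Backward pass: computes the gradient $\nabla f_W$.
  3. Weight update: computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$ and other terms, then uses it to update $W$.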

1.2 Stochastic Gradient Descent

In LeNet, the solver type is SGD (stochastic gradient descent).

SGD updates the weights with the following formulas:

$$W_{t+1} = W_t + V_{t+1}$$
$$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$$

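Here $W_{t+1}$ is the weights at iteration $t+1$; $V_{t+1}$ is the update at iteration $t+1$ (also written $\Delta W_{t+1}$); $\mu$ is the momentum, that is, the weight given to the previous update; $\alpha$ is the learning rate; and $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.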
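
The two update formulas can be sketched on plain vectors as follows. `SgdMomentumStep` is a hypothetical helper written for illustration, not a Caffe function; it assumes the gradient has already been computed by the backward pass.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Caffe): one SGD-with-momentum step,
//   V_{t+1} = mu * V_t - alpha * grad L(W_t)
//   W_{t+1} = W_t + V_{t+1}
// `w` and `v` are updated in place; `grad` holds grad L(W_t).
void SgdMomentumStep(std::vector<double>& w, std::vector<double>& v,
                     const std::vector<double>& grad,
                     double alpha, double mu) {
  assert(w.size() == v.size() && w.size() == grad.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    v[i] = mu * v[i] - alpha * grad[i];  // new update V_{t+1}
    w[i] += v[i];                        // apply it to the weights
  }
}
```

Because $V$ carries over between calls, a step taken with a zero gradient still moves the weights by $\mu V_t$, which is the momentum effect.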

2 Code Analysis

2.1 ApplyUpdate

<code class="language-c++">template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // Get the learning rate for this iteration
  Dtype rate = GetLearningRate();

  // Update the weights of every layer.
  // In LeNet only four layers have parameters: `conv1`, `conv2`, `ip1`, `ip2`.
  // Each of them has two parameter blobs, the weights and the bias,
  // so the size of `learnable_params_` is 8.
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // Normalize: a no-op when iter_size is 1, as it is for LeNet
    Normalize(param_id);
    // Regularize: add the weight-decay term to the gradient
    Regularize(param_id);
    // Compute the update value \Delta w
    ComputeUpdateValue(param_id, rate);
  }
  // Apply the update to the weights
  this->net_->Update();
}
</code>

Notes:

  1. The learning parameters for LeNet are set in lenet_solver.prototxt:

    <code class="language-c++">
    # The base learning rate, momentum and the weight decay of the network.
    base_lr: 0.01
    momentum: 0.9
    weight_decay: 0.0005
    # The learning rate policy
    lr_policy: "inv"
    gamma: 0.0001
    power: 0.75</code>
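  2. The code of the learning-rate function GetLearningRate is not reproduced here. Its comments (and caffe.proto) list the learning rate policies below. LeNet uses the inv policy, under which the learning rate changes at every iteration.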

    <code class="language-c++">  // The learning rate decay policy. The currently implemented learning rate
      // policies are as follows:
      //    - fixed: always return base_lr.
      //    - step: return base_lr * gamma ^ (floor(iter / step))
      //    - exp: return base_lr * gamma ^ iter
      //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
      //    - multistep: similar to step but it allows non uniform steps defined by
      //      stepvalue
      //    - poly: the effective learning rate follows a polynomial decay, to be
      //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
      //    - sigmoid: the effective learning rate follows a sigmoid decay
      //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
      //
      // where base_lr, max_iter, gamma, step, stepvalue and power are defined
      // in the solver parameter protocol buffer, and iter is the current iteration.</code>
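
To make the policy formulas concrete, here is a minimal sketch that evaluates the fixed, step, exp and inv policies exactly as written in the comment above. The struct `LrParams`, its field names and the function `LearningRate` are assumptions for illustration; this is not the body of Caffe's GetLearningRate.

```cpp
#include <cassert>
#include <cmath>
#include <string>

// Hypothetical stand-in for the solver parameters used by the policies.
struct LrParams {
  std::string policy;
  double base_lr;
  double gamma;
  double power;
  int stepsize;  // the "step" of the step policy; must be > 0 for "step"
};

// Sketch of four policies, following the formulas in the comment above.
double LearningRate(const LrParams& p, int iter) {
  if (p.policy == "fixed") {
    return p.base_lr;
  }
  if (p.policy == "step") {
    assert(p.stepsize > 0);
    // Integer division equals floor(iter / step) for non-negative iter.
    return p.base_lr * std::pow(p.gamma, iter / p.stepsize);
  }
  if (p.policy == "exp") {
    return p.base_lr * std::pow(p.gamma, iter);
  }
  if (p.policy == "inv") {
    return p.base_lr * std::pow(1.0 + p.gamma * iter, -p.power);
  }
  assert(false && "policy not covered by this sketch");
  return p.base_lr;
}
```

With the LeNet settings (base_lr 0.01, gamma 0.0001, power 0.75), the inv policy returns base_lr at iteration 0 and about 0.0059 at iteration 10000, since $(1 + 0.0001 \cdot 10000)^{-0.75} = 2^{-0.75} \approx 0.595$.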

2.2 Regularize

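This function effectively performs the following in-place update, where the right-hand side uses the gradient from the backward pass:

$$\frac{\partial\,\text{loss}}{\partial w_{ij}} = \text{decay} \cdot w_{ij} + \frac{\partial\,\text{loss}}{\partial w_{ij}}$$

The code is as follows: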

<code class="language-c++">template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();
  // local_decay = 0.0005 in lenet
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

  ...
      if (regularization_type == "L2") {
        // axpy means ax_plus_y. i.e., y = a*x + y
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      }
  ...
}</code>
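
The L2 branch boils down to a single axpy call. The sketch below reproduces it on `std::vector`; `Axpy` and `RegularizeL2` are hypothetical stand-ins written for this example, with `data` and `diff` playing the roles of the blob's `cpu_data()` and `mutable_cpu_diff()`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for caffe_axpy on plain vectors: y = a * x + y.
void Axpy(double a, const std::vector<double>& x, std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] += a * x[i];
}

// Hypothetical helper: adds the L2 weight-decay gradient to `diff`,
// i.e. diff = local_decay * data + diff.
void RegularizeL2(const std::vector<double>& data, std::vector<double>& diff,
                  double local_decay) {
  Axpy(local_decay, data, diff);
}
```

The weights themselves are not modified at this point; only the gradient stored in `diff` changes.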

2.3 ComputeUpdateValue

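This function effectively performs the following two in-place assignments, where $v_{ij}$ is the history (momentum) value kept for each weight:

$$v_{ij} = \text{lr\_rate} \cdot \frac{\partial\,\text{loss}}{\partial w_{ij}} + \text{momentum} \cdot v_{ij}$$
$$\frac{\partial\,\text{loss}}{\partial w_{ij}} = v_{ij}$$

The code is as follows: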

<code class="language-c++">template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  // momentum = 0.9 in lenet
  Dtype momentum = this->param_.momentum();
  // local_rate = lr_mult * global_rate
  // lr_mult is the learning rate multiplier of this parameter,
  // set in lenet_train_test.prototxt
  Dtype local_rate = rate * net_params_lr[param_id];

  // Compute the update to history, then copy it to the parameter diff.

  ...
    // axpby means ax_plus_by. i.e., y = ax + by
    // Compute the new update value \Delta w and store it in history_
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());

    // Copy the update value \Delta w from history_ into the parameter's diff
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
   ...
}</code>
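
The axpby-then-copy pattern can be sketched on plain vectors as follows. `Axpby` and `ComputeUpdate` are hypothetical names standing in for `caffe_cpu_axpby` and the excerpt above; the `history` vector plays the role of `history_[param_id]`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for caffe_cpu_axpby on plain vectors: y = a * x + b * y.
void Axpby(double a, const std::vector<double>& x, double b,
           std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = a * x[i] + b * y[i];
}

// Hypothetical helper mirroring the two calls in the excerpt:
//   history = local_rate * diff + momentum * history
//   diff    = history
void ComputeUpdate(std::vector<double>& diff, std::vector<double>& history,
                   double local_rate, double momentum) {
  Axpby(local_rate, diff, momentum, history);
  diff = history;  // plays the role of caffe_copy
}
```

Note the sign convention: `history` accumulates $\alpha \nabla L + \mu \cdot \text{history}$, which is the negative of $V$ in section 1.2. The sign is restored in section 2.4, where the diff is subtracted from the weights.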

2.4 net_->Update

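Net::Update() calls Blob::Update() on every learnable parameter, which effectively performs:
$$w_{ij} = w_{ij} + (-1) \cdot \frac{\partial\,\text{loss}}{\partial w_{ij}}$$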

<code class="language-c++">caffe_axpy<Dtype>(count_, Dtype(-1),
        static_cast<const Dtype*>(diff_->cpu_data()),
        static_cast<Dtype*>(data_->mutable_cpu_data()));</code>
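
Putting sections 2.2 to 2.4 together, one update of a single parameter blob can be sketched as below. `ParamState` and `ApplyUpdateSketch` are hypothetical names; plain vectors replace Caffe's Blob, only the L2 branch is covered, and Normalize is left out because it does nothing when iter_size is 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-parameter state: a blob's data and diff plus the
// solver's history buffer, all as plain vectors.
struct ParamState {
  std::vector<double> data;     // weights w_ij
  std::vector<double> diff;     // d loss / d w_ij from the backward pass
  std::vector<double> history;  // momentum buffer v_ij
};

// One update in the order Regularize -> ComputeUpdateValue -> Update.
void ApplyUpdateSketch(ParamState& p, double local_rate, double momentum,
                       double local_decay) {
  assert(p.data.size() == p.diff.size());
  assert(p.data.size() == p.history.size());
  for (std::size_t i = 0; i < p.data.size(); ++i) {
    // 2.2: diff = local_decay * data + diff
    p.diff[i] += local_decay * p.data[i];
    // 2.3: history = local_rate * diff + momentum * history
    p.history[i] = local_rate * p.diff[i] + momentum * p.history[i];
    // 2.3: copy the update back into diff
    p.diff[i] = p.history[i];
    // 2.4: data = data + (-1) * diff
    p.data[i] += -1.0 * p.diff[i];
  }
}
```

Since `history` equals $-V$, subtracting the diff from the weights gives $W_{t+1} = W_t + V_{t+1}$, matching the formulas in section 1.2 with the weight-decay term folded into the gradient.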

References:

[1]. http://caffe.berkeleyvision.org/tutorial/solver.html

