Reposted from: http://blog.csdn.net/qq_16055159/article/details/45068147
Solver<Dtype>::Solver(const SolverParameter& param)
Purpose: constructor
Steps: initialize the two Net members, net_ and test_net_, then call Init()
Input: param, of type SolverParameter
Output: none
Solver<Dtype>::Solver(const string& param_file)
Purpose: constructor
Steps: initialize the two Net members, net_ and test_net_, then call Init()
Input: param_file, of type string
Output: none
void Solver<Dtype>::Init(const SolverParameter& param)
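Purpose: initialize the networks
Steps:
1. Set the random number seed
2. Allocate a Net, constructed with the constructor shown below using param_file = train_net_, and point net_ at it
3. If a test_net is given, allocate another Net and point test_net_ at it
Input: param, of type SolverParameter
Output: none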
```cpp
Net<Dtype>::Net(const string& param_file) {
  NetParameter param;
  ReadNetParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}
```
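The three Init() steps can be sketched in a self-contained form. This is only an illustration under stated assumptions: ToyNet, ToySolverParameter, and ToySolver are hypothetical stand-ins, not Caffe's own classes, and the field names random_seed, train_net, and test_net are chosen here for readability.

```cpp
#include <cassert>
#include <memory>
#include <random>
#include <string>

// Hypothetical stand-in for a network built from a definition file.
struct ToyNet {
  explicit ToyNet(const std::string& param_file) : source(param_file) {}
  std::string source;  // remembers which file the net was built from
};

// Hypothetical stand-in for the solver configuration.
struct ToySolverParameter {
  long random_seed = -1;  // negative means "no seed supplied"
  std::string train_net;  // path of the training net definition
  std::string test_net;   // empty means "no test net"
};

class ToySolver {
 public:
  // The constructor only delegates to Init(), as in the notes above.
  explicit ToySolver(const ToySolverParameter& param) { Init(param); }

  const ToyNet* net() const { return net_.get(); }
  const ToyNet* test_net() const { return test_net_.get(); }

 private:
  void Init(const ToySolverParameter& param) {
    assert(!param.train_net.empty());
    // Step 1: seed the random number generator when a seed is supplied.
    if (param.random_seed >= 0) {
      rng_.seed(static_cast<std::mt19937::result_type>(param.random_seed));
    }
    // Step 2: allocate the training net; net_ points to it.
    net_ = std::make_shared<ToyNet>(param.train_net);
    // Step 3: allocate the test net only when one is configured.
    if (!param.test_net.empty()) {
      test_net_ = std::make_shared<ToyNet>(param.test_net);
    }
  }

  std::mt19937 rng_;
  std::shared_ptr<ToyNet> net_;
  std::shared_ptr<ToyNet> test_net_;
};
```

Constructing a ToySolver with only train_net set leaves test_net() null, which mirrors the "if a test_net is given" condition in step 3.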
void Solver<Dtype>::Solve(const char* resume_file)
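Purpose: train the network
Steps:
1. Set Caffe's mode (GPU or CPU)
2. If the mode is GPU and a GPU device ID is specified, select that device
3. Set the current phase (TRAIN or TEST) to TRAIN
4. Call PreSolve()
5. Call Restore(resume_file)
6. Run Test() once to check that there is enough memory
7. For each training iteration (one pass over the whole network): while (iter_++ < param_.max_iter())
- Compute the loss:
loss = net_->ForwardBackward(bottom_vec)
where: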
```cpp
// *************** ForwardBackward() ************
Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
  Dtype loss;
  Forward(bottom, &loss);
  Backward();
  return loss;
}

// *************** Forward() ***********
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
    const vector<Blob<Dtype>*> & bottom, Dtype* loss) {
  // Copy bottom to internal bottom
  for (int i = 0; i < bottom.size(); ++i) {
    net_input_blobs_[i]->CopyFrom(*bottom[i]);
  }
  return ForwardPrefilled(loss);
}

// *************** ForwardPrefilled() ************
const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
  if (loss != NULL) {
    *loss = Dtype(0.);
  }
  for (int i = 0; i < layers_.size(); ++i) {
    // LOG(ERROR) << "Forwarding " << layer_names_[i];
    Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
    if (loss != NULL) {
      *loss += layer_loss;  // every non-loss layer returns 0: return Dtype(0.);
    }
  }
  return net_output_blobs_;
}

// *************** Layer::Forward() ************
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  switch (Caffe::mode()) {
  case Caffe::CPU:
    return Forward_cpu(bottom, top);  // virtual; each layer type computes differently
  case Caffe::GPU:
    return Forward_gpu(bottom, top);
  default:
    LOG(FATAL) << "Unknown caffe mode.";
    return Dtype(0);
  }
}

// *************** Backward() ************
void Net<Dtype>::Backward() {
  for (int i = layers_.size() - 1; i >= 0; --i) {
    if (layer_need_backward_[i]) {
      layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
    }
  }
}
```
- Call ComputeUpdateValue()
- Print the loss
- Every test_interval iterations, call Test()
- Every snapshot iterations, call Snapshot()
8. After the loop ends, call Snapshot() one last time
Input: resume_file, of type char*
Output: none
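The control flow of the training loop in Solve() can be sketched as follows. This is a simplified illustration, not the Caffe implementation: LoopHooks, LoopParams, and RunTrainingLoop are hypothetical names, the callbacks stand in for the Solver and Net member functions, and the convention that an interval of 0 disables the periodic call is an assumption of this sketch. Loss printing is omitted.

```cpp
#include <cassert>
#include <functional>

// Hypothetical callbacks standing in for the Solver/Net member functions.
struct LoopHooks {
  std::function<float()> forward_backward;  // returns the loss of one pass
  std::function<void()> compute_update;     // stand-in for ComputeUpdateValue()
  std::function<void()> test;               // stand-in for Test()
  std::function<void()> snapshot;           // stand-in for Snapshot()
};

// Hypothetical subset of the solver settings used by the loop.
struct LoopParams {
  int max_iter = 0;
  int test_interval = 0;  // 0 disables periodic testing (assumption)
  int snapshot = 0;       // 0 disables periodic snapshots (assumption)
};

// Runs the loop and returns the loss of the last iteration.
inline float RunTrainingLoop(const LoopParams& param, const LoopHooks& hooks) {
  assert(param.max_iter >= 0);
  float loss = 0.0f;
  int iter = 0;
  // Same loop shape as the notes: the counter is incremented in the condition,
  // so iter runs from 1 to max_iter inside the body.
  while (iter++ < param.max_iter) {
    loss = hooks.forward_backward();
    hooks.compute_update();
    if (param.test_interval > 0 && iter % param.test_interval == 0) {
      hooks.test();
    }
    if (param.snapshot > 0 && iter % param.snapshot == 0) {
      hooks.snapshot();
    }
  }
  hooks.snapshot();  // final snapshot once training is done
  return loss;
}
```

With max_iter = 10, test_interval = 5, and snapshot = 4, the test hook fires at iterations 5 and 10, and the snapshot hook fires at iterations 4 and 8 plus once after the loop.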
void Solver<Dtype>::Test()
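Purpose: test the network
Input: none
Output: none
Steps:
1. Set the current phase (TRAIN or TEST) to TEST
2. Point test_net_ at net_, so that both operate on the same network
3. For each test iteration: for (int i = 0; i < param_.test_iter(); ++i)
  - Run the statement below to assign net_output_blobs_ to result (result holds all the output-layer blobs) and to obtain iter_loss for this iteration:
    result = test_net_->Forward(bottom_vec, &iter_loss)
  - On the first test iteration:
    - Take each output-layer blob:
      result_vec = result[j]->cpu_data()
    - Append each blob's data (flattened to one dimension) to a vector named test_score
  - On later test iterations:
    - Accumulate the blob values at the corresponding positions, using
      test_score[idx++] += result_vec[k]
      instead of test_score.push_back(result_vec[k])
4. Print the test loss if requested
5. Print test_score if requested
6. Set the current phase (TRAIN or TEST) back to TRAIN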
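The score accumulation in Test() (push_back on the first iteration, += on later ones) can be sketched as below. Outputs, AccumulateScores, and MeanScores are hypothetical names, and each blob is modelled as a flat std::vector<float>. Dividing by the number of iterations at the end is an assumption of this sketch; the notes above only say that test_score is printed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One test iteration's result: a list of output blobs, each flattened to 1-D.
using Outputs = std::vector<std::vector<float>>;

// Adds one iteration's outputs into test_score.
inline void AccumulateScores(const Outputs& result, bool first_iteration,
                             std::vector<float>* test_score) {
  if (first_iteration) {
    // First iteration: create one slot per output value.
    for (const auto& blob : result) {
      for (float v : blob) test_score->push_back(v);
    }
  } else {
    // Later iterations: add into the slot at the same position.
    std::size_t idx = 0;
    for (const auto& blob : result) {
      for (float v : blob) {
        assert(idx < test_score->size());
        (*test_score)[idx++] += v;
      }
    }
  }
}

// Mean score per position over all iterations (averaging is an assumption).
inline std::vector<float> MeanScores(const std::vector<Outputs>& iterations) {
  std::vector<float> test_score;
  for (std::size_t i = 0; i < iterations.size(); ++i) {
    AccumulateScores(iterations[i], i == 0, &test_score);
  }
  for (float& v : test_score) v /= static_cast<float>(iterations.size());
  return test_score;
}
```

Using push_back only on the first pass fixes the length of test_score, so every later pass can index into it instead of growing it.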
void Solver<Dtype>::Snapshot()
Purpose: write the current network state to a file (not important here)
Input: none
Output: none
void Solver<Dtype>::Restore(const char* state_file)
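Purpose: read the network state from a file so that training can resume from that state (not important here)
Input: the file name
Output: none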
Dtype SGDSolver<Dtype>::GetLearningRate()
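Purpose: return the learning rate
Steps:
1. Read the learning rate policy: const string& lr_policy = this->param_.lr_policy()
2. Branch on the policy (the policies are described in the source comments)
3. Return the learning rate
Input: none
Output: rate, of type Dtype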
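The branch on lr_policy can be sketched as below, assuming the four policies "fixed", "step", "exp", and "inv" with their usual formulas. LrParams and LearningRate are hypothetical names, the default values are placeholders, and the real function reads these fields from the solver parameters rather than from a struct like this one.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>

// Hypothetical subset of the solver settings that affect the learning rate.
struct LrParams {
  std::string lr_policy;  // "fixed", "step", "exp" or "inv" (assumed set)
  double base_lr = 0.01;
  double gamma = 0.1;
  double power = 0.75;
  int stepsize = 1;
};

// Returns the learning rate for the given iteration.
inline double LearningRate(const LrParams& p, int iter) {
  if (p.lr_policy == "fixed") {
    return p.base_lr;  // constant rate
  }
  if (p.lr_policy == "step") {
    assert(p.stepsize > 0);
    const int current_step = iter / p.stepsize;  // integer division = floor
    return p.base_lr * std::pow(p.gamma, current_step);
  }
  if (p.lr_policy == "exp") {
    return p.base_lr * std::pow(p.gamma, iter);
  }
  if (p.lr_policy == "inv") {
    return p.base_lr * std::pow(1.0 + p.gamma * iter, -p.power);
  }
  throw std::invalid_argument("Unknown learning rate policy: " + p.lr_policy);
}
```

For example, with the "step" policy, base_lr = 0.1, gamma = 0.5, and stepsize = 10, the rate at iteration 25 is 0.1 * 0.5^2 = 0.025.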
void SGDSolver<Dtype>::PreSolve()
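Purpose: set up before training starts
Steps:
1. Read the parameters of the training network net_ into net_params: net_params = this->net_->params()
where params_ is a vector of blob pointers
2. Clear any leftover history
3. For each parameter blob of the network, push a blob of the same size onto history_
Input: none
Output: none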
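A minimal sketch of the PreSolve() steps, assuming each parameter blob can be modelled as a flat std::vector<float>. ToyBlob and PreSolveSketch are hypothetical names, and zero-filling the new buffers is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a parameter blob: a flat buffer of values.
using ToyBlob = std::vector<float>;

// Rebuilds history so that it holds one buffer per parameter blob,
// each with the same number of elements as that blob.
inline void PreSolveSketch(const std::vector<ToyBlob>& net_params,
                           std::vector<ToyBlob>* history) {
  history->clear();  // step 2: drop any leftover history
  for (const ToyBlob& param : net_params) {
    // step 3: one history buffer per parameter blob, same size, zero-filled
    history->push_back(ToyBlob(param.size(), 0.0f));
  }
  assert(history->size() == net_params.size());
}
```

Keeping history index-aligned with net_params is what lets the update step address both with the same param_id.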
void SGDSolver<Dtype>::ComputeUpdateValue()
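Purpose: compute the update values with stochastic gradient descent
Input: none
Output: none
Steps:
1. (For all layers) read the network parameters net_params, the per-parameter learning rates net_params_lr, and the weight decays net_params_weight_decay; read the learning rate rate
2. (For the current layer) read the momentum and the weight decay
3. In CPU mode:
For each layer:
- Compute local_rate and local_decay
- Call caffe_cpu_axpby, caffe_axpy, and caffe_copy: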
```cpp
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
    net_params[param_id]->cpu_diff(), momentum,
    history_[param_id]->mutable_cpu_data());

caffe_axpy(net_params[param_id]->count(), local_decay * local_rate,
    net_params[param_id]->cpu_data(),
    history_[param_id]->mutable_cpu_data());
```
```cpp
template <>
void caffe_cpu_axpby<float>(const int N, const float alpha, const float* X,
    const float beta, float* Y) {
  cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
}

// where:
inline void cblas_saxpby(const int N, const float alpha, const float* X,
    const int incX, const float beta, float* Y, const int incY) {
  cblas_sscal(N, beta, Y, incY);
  cblas_saxpy(N, alpha, X, incX, Y, incY);
}
```
caffe_cpu_axpby calls cblas_saxpby, which in turn calls cblas_sscal and cblas_saxpy
```cpp
template <>
void caffe_axpy<float>(const int N, const float alpha, const float* X,
    float* Y) {
  cblas_saxpy(N, alpha, X, 1, Y, 1);
}
```
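caffe_axpy calls cblas_saxpy directly
So, compared with caffe_axpy, caffe_cpu_axpby takes one extra parameter, beta, and makes one extra call: cblas_sscal(N, beta, Y, incY); that is, it computes Y = alpha*X + beta*Y rather than Y = alpha*X + Y
4. The GPU case is analogous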
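The CPU update for a single parameter blob can be sketched with plain loops in place of the BLAS calls. This is an illustration under assumptions: Axpby, Axpy, ToyParam, and ComputeUpdateSketch are hypothetical names, the multipliers lr_mult and decay_mult stand in for the entries of net_params_lr and net_params_weight_decay, and the final copy of the history back into diff is assumed to be what the caffe_copy call does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// y = alpha * x + beta * y (the operation caffe_cpu_axpby performs)
inline void Axpby(float alpha, const std::vector<float>& x, float beta,
                  std::vector<float>* y) {
  assert(x.size() == y->size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*y)[i] = alpha * x[i] + beta * (*y)[i];
  }
}

// y = alpha * x + y (the operation caffe_axpy performs)
inline void Axpy(float alpha, const std::vector<float>& x,
                 std::vector<float>* y) {
  assert(x.size() == y->size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*y)[i] += alpha * x[i];
  }
}

// Hypothetical parameter blob: weights (data) and their gradient (diff).
struct ToyParam {
  std::vector<float> data;
  std::vector<float> diff;
};

// Computes the momentum-SGD update value for one parameter blob.
inline void ComputeUpdateSketch(float rate, float momentum, float weight_decay,
                                float lr_mult, float decay_mult,
                                ToyParam* param, std::vector<float>* history) {
  const float local_rate = rate * lr_mult;
  const float local_decay = weight_decay * decay_mult;
  // history = local_rate * diff + momentum * history
  Axpby(local_rate, param->diff, momentum, history);
  // history += local_decay * local_rate * data (weight decay term)
  if (local_decay != 0.0f) {
    Axpy(local_decay * local_rate, param->data, history);
  }
  // Assumed role of caffe_copy: the update value replaces the gradient.
  param->diff = *history;
}
```

With data = {2}, diff = {1}, history = {1}, rate = 0.5, momentum = 0.5, and weight_decay = 0.5 (both multipliers 1), the history becomes 0.5*1 + 0.5*1 = 1.0 and then 1.0 + 0.25*2 = 1.5.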
void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state)
Omitted.
void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state)
Omitted.