Analysis of the Solver Files in the Caffe Source Code

Reference: https://www.2cto.com/kf/201703/615653.html

The Caffe source code contains a number of important header files. This post walks through the contents of include/caffe/solver.hpp:

1. Include files: an introduction to the included caffe/net.hpp can be found at: https://blog.csdn.net/fengbingchun/article/details/62423060

2. The Solver template class: an abstract base class

3. The WorkerSolver template class: derived from Solver; used in multi-GPU training to compute gradients only

4. The SGDSolver template class: derived from Solver

5. The NesterovSolver template class: derived from SGDSolver

6. The AdaGradSolver template class: derived from SGDSolver

7. The RMSPropSolver template class: derived from SGDSolver

8. The AdaDeltaSolver template class: derived from SGDSolver

9. The AdamSolver template class: derived from SGDSolver

10. The GetSolver function: news a solver object of the requested type

The Solver coordinates the Net's forward inference and backward gradient computation to update the parameters and thereby reduce the loss. Learning in a Caffe model is divided into two parts: the Solver performs the optimization and updates the parameters, while the Net computes the loss and the gradients.

solver.prototxt is a configuration file that tells Caffe how to train the network.
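For illustration, here is a minimal solver.prototxt modeled on the MNIST LeNet solver shipped with Caffe (the same kind of file the test program at the end of this post parses); the path and field values are only an example, not a canonical configuration:

# Hypothetical minimal solver.prototxt, modeled on Caffe's MNIST LeNet solver;
# the path and values below are illustrative.
net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100          # test batches per test phase
test_interval: 500      # run the test net every 500 training iterations
base_lr: 0.01           # base learning rate
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"        # learning rate decay policy
gamma: 0.0001
power: 0.75
display: 100            # print progress every 100 iterations
max_iter: 10000
snapshot: 5000          # snapshot every 5000 iterations
snapshot_prefix: "examples/mnist/lenet"
solver_mode: CPU        # CPU or GPU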

With a Net we can already run the forward and backward passes, but the training and prediction workflow needs more than that; the Solver class wraps this additional functionality. A Solver defines how a Net model is optimized, records the training process, saves the model parameters, and can interrupt and later resume training. Custom Solvers make it possible to implement different solving strategies.

The solvers supported by Caffe include:

(1) Stochastic Gradient Descent (type: "SGD"): updates the weights using a linear combination of the negative gradient and the previous weight update. The learning rate is the weight on the negative gradient; the momentum is the weight on the previous update. A common recipe is to initialize the learning rate to about 0.01 and, whenever the loss plateaus during training, divide it by a constant (e.g. 10), repeating this several times. Momentum is usually set to 0.9; it smooths the weight updates and makes learning more stable and faster. The update rule is given below.
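The SGD update as stated in the Caffe tutorial, where V is the accumulated update, W the weights, \alpha the learning rate and \mu the momentum:

V_{t+1} = \mu V_t - \alpha \nabla L(W_t)
W_{t+1} = W_t + V_{t+1}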

(2) AdaDelta (type: "AdaDelta"): a "robust learning rate method"; like SGD, it is a gradient-based optimization method. Its update rule is sketched below.
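A sketch of the per-coordinate AdaDelta update as presented in the Caffe tutorial (g_t denotes the gradient \nabla L(W_t)):

(v_t)_i = \frac{\operatorname{RMS}((v_{t-1})_i)}{\operatorname{RMS}(g_t)_i}\,(g_t)_i
\operatorname{RMS}(x)_i = \sqrt{E[x^2]_i + \varepsilon}, \qquad E[g^2]_t = \delta\,E[g^2]_{t-1} + (1-\delta)\,g_t^2
(W_{t+1})_i = (W_t)_i - \alpha (v_t)_i

In the SolverParameter message, the decay \delta comes from the momentum field and \varepsilon from the delta field.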

(3) Adaptive Gradient (type: "AdaGrad"): adaptive gradient descent; like stochastic gradient descent, it is a gradient-based optimization method. It scales each weight's step by the history of that weight's squared gradients, as shown below.
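The per-coordinate AdaGrad update from the Caffe tutorial:

(W_{t+1})_i = (W_t)_i - \alpha \frac{(\nabla L(W_t))_i}{\sqrt{\sum_{t'=1}^{t} \left((\nabla L(W_{t'}))_i\right)^2}}

(in the implementation a small \varepsilon, the delta field, is added for numerical stability).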

(4) Adam (type: "Adam"): also a gradient-based optimization method. It maintains a pair of adaptive moment-estimate variables and can be seen as a generalization of AdaGrad; see the update below.
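The Adam update as given in the Caffe tutorial, with first and second moment estimates m and v:

(m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i
(v_t)_i = \beta_2 (v_{t-1})_i + (1-\beta_2)(\nabla L(W_t))_i^2
(W_{t+1})_i = (W_t)_i - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{(m_t)_i}{\sqrt{(v_t)_i} + \varepsilon}

In the SolverParameter message, momentum = \beta_1, momentum2 = \beta_2 and delta = \varepsilon.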

(5) Nesterov's Accelerated Gradient (type: "Nesterov"): an optimal method for convex optimization, with a convergence rate of O(1/t^2) instead of O(1/t). Although the assumptions behind the O(1/t^2) rate are hard to satisfy when training deep neural networks with Caffe, in practice NAG is still a very effective method for certain kinds of deep learning architectures. Its update rule is shown below.
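The NAG update from the Caffe tutorial; the only difference from plain SGD with momentum is that the gradient is evaluated at the look-ahead point W_t + \mu V_t:

V_{t+1} = \mu V_t - \alpha \nabla L(W_t + \mu V_t)
W_{t+1} = W_t + V_{t+1}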

(6) RMSprop (type: "RMSProp"): a gradient-based optimization method (similar to SGD) that divides each step by a running mean of squared gradients, as shown below.
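The RMSProp update from the Caffe tutorial, where \delta is the rms_decay field:

\operatorname{MS}((W_t)_i) = \delta \operatorname{MS}((W_{t-1})_i) + (1-\delta)(\nabla L(W_t))_i^2
(W_{t+1})_i = (W_t)_i - \alpha \frac{(\nabla L(W_t))_i}{\sqrt{\operatorname{MS}((W_t)_i)}}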

The Solver:

(1) scaffolds the optimization bookkeeping and creates the training network (for learning) and test networks (for evaluation);

(2) iteratively optimizes and updates the parameters by running the forward and backward passes;

(3) periodically evaluates the test networks to measure model performance;

(4) takes snapshots of the model and the solver state during optimization.

In each iteration, the solver:

(1) calls the Net forward pass to compute the outputs and the loss;

(2) calls the Net backward pass to compute the gradients (derivatives of the loss with respect to each layer's weights w and biases b);

(3) applies the gradients as a parameter update according to the solver method described above;

(4) updates the solver state according to the learning rate, the history and the chosen method, moving the weights from their initial state toward the final learned state (see the sketch after this list).
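The loop below is a simplified, hypothetical sketch of what Solver<Dtype>::Step does each iteration; the real Caffe implementation additionally handles iter_size gradient accumulation, multi-GPU callbacks and more detailed logging.

// Simplified, illustrative sketch of the per-iteration loop in
// Solver<Dtype>::Step; not the verbatim Caffe implementation.
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  vector<Blob<Dtype>*> bottom_vec;  // empty bottom for nets driven by data layers
  const int stop_iter = iter_ + iters;
  while (iter_ < stop_iter) {
    // periodically evaluate the test nets
    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
      TestAll();
    }
    // (1) + (2): the forward pass computes the outputs and the loss; the
    // backward pass writes the gradients into each parameter blob's diff
    Dtype loss = net_->ForwardBackward(bottom_vec);
    if (param_.display() && iter_ % param_.display() == 0) {
      LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
    }
    // (3) + (4): the concrete solver (SGD, Adam, ...) combines the gradients,
    // the learning rate and its history into an actual parameter update
    ApplyUpdate();
    ++iter_;
    // periodically snapshot the learned net and the solver state
    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
      Snapshot();
    }
  }
}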

Solvers can run in either of two modes: CPU or GPU.

The solver method minimizes the loss. Given a dataset D, the optimization objective is the average loss over all the data in D, as formalized below.
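From the Caffe tutorial, the objective is

L(W) = \frac{1}{|D|} \sum_{i}^{|D|} f_W\!\left(X^{(i)}\right) + \lambda r(W)

where f_W(X^{(i)}) is the loss on data instance X^{(i)} and r(W) is a regularization term with weight \lambda. Since |D| can be very large, each iteration in practice uses a stochastic approximation of this objective over a mini-batch of N \ll |D| instances:

L(W) \approx \frac{1}{N} \sum_{i}^{N} f_W\!\left(X^{(i)}\right) + \lambda r(W)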

Note: the description of the Solver above is taken mainly from the Caffe official tutorial (the Chinese translation produced by the CaffeCN community).

A detailed, annotated version of the file follows:


#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
 
#include <string>
#include <vector>
 
#include "caffe/net.hpp"
 
namespace caffe {
 
/**
  * @brief An interface for classes that perform optimization on Net%s.
  *
  * Requires implementation of ApplyUpdate to compute a parameter update
  * given the current state of the Net parameters.
  */
template <typename Dtype>
class Solver { // Solver template class, an abstract base class
 public:
  // explicit constructors; both call the Init function internally
  explicit Solver(const SolverParameter& param, const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  // assigns the member variables param_, iter_ and current_step_, then calls InitTrainNet and InitTestNets
  void Init(const SolverParameter& param);
  // assigns the member variable net_
  void InitTrainNet();
  // assigns the member variable test_nets_
  void InitTestNets();
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  // calls Restore, Step and Snapshot in turn, then runs net_'s forward pass
  // ForwardPrefilled, and finally calls TestAll
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  // repeatedly runs the net's forward/backward passes; along the way it calls
  // TestAll, ApplyUpdate, Snapshot and the two member functions of the Callback class
  void Step(int iters);
  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  // loads an existing model
  void Restore(const char* resume_file);
  // virtual destructor
  virtual ~Solver() {}
  // gets the solver parameter
  inline const SolverParameter& param() const { return param_; }
  // gets the train net
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  // gets the test nets
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  // gets the current iteration count
  int iter() { return iter_; }
  // Invoked at specific points during an iteration
  // nested Callback class, used only in multi-GPU mode
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  // gets the Callbacks
  const vector<Callback*>& callbacks() const { return callbacks_; }
  // adds a Callback
  void add_callback(Callback* value) { callbacks_.push_back(value); }
 
 protected:
  // Make and apply the update value for the current iteration.
  // updates the net's weights and biases
  virtual void ApplyUpdate() = 0;
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  // takes a snapshot; internally calls SnapshotToBinaryProto or SnapshotToHDF5, plus SnapshotSolverState
  void Snapshot();
  // gets the snapshot file name
  string SnapshotFilename(const string extension);
  // writes the proto to a .caffemodel file
  string SnapshotToBinaryProto();
  // writes the proto to an HDF5 file
  string SnapshotToHDF5();
  // The test routine
  // calls the Test function in a loop
  void TestAll();
  // runs a test net, i.e. a net forward pass
  void Test(const int test_net_id = 0);
  // stores the snapshot solver state
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  // reads an HDF5 file into the solver state
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  // reads a binary .solverstate file into the solver state
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  // dummy function, declared but not implemented
  void DisplayOutputBlobs(const int net_id);

  // member variable names in Caffe carry a trailing "_", which makes it easy
  // to distinguish class members from local variables
  SolverParameter param_; // solver parameter
  int iter_; // current iteration count
  int current_step_; // current step of the learning rate policy
  shared_ptr<Net<Dtype> > net_; // train net
  vector<shared_ptr<Net<Dtype> > > test_nets_; // test nets
  vector<Callback*> callbacks_; // Callbacks

  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;

  // forbids copy and assignment of the Solver class
  DISABLE_COPY_AND_ASSIGN(Solver);
};
 
/**
  * @brief Solver that only computes gradients, used as worker
  *        for multi-GPU training.
  */
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> { // WorkerSolver template class, derived from Solver
 public:
  // explicit constructor
  explicit WorkerSolver(const SolverParameter& param, const Solver<Dtype>* root_solver = NULL)
      : Solver<Dtype>(param, root_solver) {}

 protected:
  void ApplyUpdate() {}
  void SnapshotSolverState(const string& model_filename) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromBinaryProto(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromHDF5(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
};
 
/**
  * @brief Optimizes the parameters of a Net using
  *        stochastic gradient descent (SGD) with momentum.
  */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> { // SGDSolver template class, derived from Solver
 public:
  // explicit constructors; both call the PreSolve function
  explicit SGDSolver(const SolverParameter& param) : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file) : Solver<Dtype>(param_file) { PreSolve(); }
  // gets the history data
  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  // initializes the member variables history_, update_ and temp_
  void PreSolve();
  // gets the learning rate
  Dtype GetLearningRate();
  // calls ClipGradients, Normalize, Regularize and ComputeUpdateValue to update the net's weights and biases
  virtual void ApplyUpdate();
  // calls the caffe_scal function
  virtual void Normalize(int param_id);
  // calls the caffe_axpy function
  virtual void Regularize(int param_id);
  // computes and updates the corresponding Blob values; calls the caffe_cpu_axpby and caffe_copy functions
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  // clips parameter gradients to the configured L2 norm: if the gradients grow
  // too large, all of them are multiplied by a scale factor so that their
  // overall norm does not exceed the value set in the parameters
  virtual void ClipGradients();
  // stores the snapshot solver state; internally calls SnapshotSolverStateToBinaryProto or SnapshotSolverStateToHDF5
  virtual void SnapshotSolverState(const string& model_filename);
  // writes the solver state to a binary .solverstate file
  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
  // writes the solver state to HDF5
  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
  // reads an HDF5 file into the solver state
  virtual void RestoreSolverStateFromHDF5(const string& state_file);
  // reads a binary .solverstate file into the solver state
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
  // history maintains the historical momentum data.
  // update maintains update related data and is not needed in snapshots.
  // temp maintains other information that might be needed in computation
  //   of gradients/updates and is not needed in snapshots
  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

  // forbids copy and assignment of the SGDSolver class
  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};
 
template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> { // NesterovSolver template class, derived from SGDSolver
 public:
  // explicit constructors
  explicit NesterovSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) {}
  explicit NesterovSolver(const string& param_file) : SGDSolver<Dtype>(param_file) {}

 protected:
  // computes and updates the corresponding Blob values; calls the caffe_cpu_axpby and caffe_copy functions
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  // forbids copy and assignment of the NesterovSolver class
  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> { // AdaGradSolver template class, derived from SGDSolver
 public:
  // explicit constructors; both call the constructor_sanity_check function
  explicit AdaGradSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit AdaGradSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

 protected:
  // computes and updates the corresponding Blob values
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  // forbids copy and assignment of the AdaGradSolver class
  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
 
 
template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> { // RMSPropSolver template class, derived from SGDSolver
 public:
  // explicit constructors; both call the constructor_sanity_check function
  explicit RMSPropSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
  explicit RMSPropSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

 protected:
  // computes and updates the corresponding Blob values
  virtual void ComputeUpdateValue(int param_id, Dtype rate);
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with RMSProp.";
    CHECK_GE(this->param_.rms_decay(), 0)
        << "rms_decay should lie between 0 and 1.";
    CHECK_LT(this->param_.rms_decay(), 1)
        << "rms_decay should lie between 0 and 1.";
  }

  // forbids copy and assignment of the RMSPropSolver class
  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> { // AdaDeltaSolver template class, derived from SGDSolver
 public:
  // explicit constructors; both call the AdaDeltaPreSolve function
  explicit AdaDeltaSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
  explicit AdaDeltaSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }

 protected:
  void AdaDeltaPreSolve();
  // computes and updates the corresponding Blob values
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  // forbids copy and assignment of the AdaDeltaSolver class
  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};
 
/**
  * @brief AdamSolver, an algorithm for first-order gradient-based optimization
  *        of stochastic objective functions, based on adaptive estimates of
  *        lower-order moments. Described in [1].
  *
  * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
  *     arXiv preprint arXiv:1412.6980v8 (2014).
  */
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> { // AdamSolver template class, derived from SGDSolver
 public:
  // explicit constructors; both call the AdamPreSolve function
  explicit AdamSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdamPreSolve(); }
  explicit AdamSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }

 protected:
  void AdamPreSolve();
  // computes and updates the corresponding Blob values
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  // forbids copy and assignment of the AdamSolver class
  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};
 
// news a solver object of the specified type
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
  SolverParameter_SolverType type = param.solver_type();

  switch (type) {
  case SolverParameter_SolverType_SGD:
      return new SGDSolver<Dtype>(param);
  case SolverParameter_SolverType_NESTEROV:
      return new NesterovSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAGRAD:
      return new AdaGradSolver<Dtype>(param);
  case SolverParameter_SolverType_RMSPROP:
      return new RMSPropSolver<Dtype>(param);
  case SolverParameter_SolverType_ADADELTA:
      return new AdaDeltaSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAM:
      return new AdamSolver<Dtype>(param);
  default:
      LOG(FATAL) << "Unknown SolverType: " << type;
  }
  return (Solver<Dtype>*) NULL;
}

}  // namespace caffe

#endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
In caffe.proto, there is one message that is mainly related to the solver, as follows:
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter { // solver parameters
   //
   // Specifying the train and test networks
   //
   // Exactly one train net must be specified using one of the following fields:
   //     train_net_param, train_net, net_param, net
   // One or more test nets may be specified using any of the following fields:
   //     test_net_param, test_net, net_param, net
   // If more than one test net field is specified (e.g., both net and
   // test_net are specified), they will be evaluated in the field order given
   // above: (1) test_net_param, (2) test_net, (3) net_param/net.
   // A test_iter must be specified for each test_net.
   // A test_level and/or a test_stage may also be specified for each test_net.
   //
 
   // Proto filename for the train net, possibly combined with one or more test nets.
  optional string net = 24; // .prototxt file name, train or test net
  // Inline train net param, possibly combined with one or more test nets.
  optional NetParameter net_param = 25; // net parameter message

  optional string train_net = 1; // Proto filename for the train net (.prototxt)
  repeated string test_net = 2; // Proto filenames for the test nets (.prototxt)
  optional NetParameter train_net_param = 21; // Inline train net params
  repeated NetParameter test_net_param = 22; // Inline test net params
 
   // The states for the train/test nets. Must be unspecified or
   // specified once per net.
   //
   // By default, all states will have solver = true;
   // train_state will have phase = TRAIN,
   // and all test_state's will have phase = TEST.
   // Other defaults are set according to the NetState defaults.
  optional NetState train_state = 26; // train net state
  repeated NetState test_state = 27; // test net states

  // The number of iterations for each test net.
  repeated int32 test_iter = 3; // iterations a test net (used for evaluation) runs per test phase; test_iter * batch_size = total number of test images

  // The number of iterations between two testing phases.
  optional int32 test_interval = 4 [default = 0]; // run the test nets once every test_interval training iterations
  optional bool test_compute_loss = 19 [default = false]; // whether to compute the loss when running the test nets
  // If true, run an initial test pass before the first iteration,
  // ensuring memory availability and printing the starting value of the loss.
  optional bool test_initialization = 32 [default = true]; // whether to run the test nets once before training starts
  optional float base_lr = 5; // The base learning rate
  // the number of iterations between displaying info. If display = 0, no info
  // will be displayed.
  optional int32 display = 6; // display result info every `display` iterations
  // Display the loss averaged over the last average_loss iterations
  optional int32 average_loss = 33 [default = 1];
  optional int32 max_iter = 7; // the maximum number of iterations
  // accumulate gradients over `iter_size` x `batch_size` instances
  optional int32 iter_size = 36 [default = 1];
 
  // The learning rate decay policy. The currently implemented learning rate
  // policies are as follows:
  //    - fixed: always return base_lr.
  //    - step: return base_lr * gamma ^ (floor(iter / step))
  //    - exp: return base_lr * gamma ^ iter
  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
  //    - multistep: similar to step but it allows non uniform steps defined by
  //      stepvalue
  //    - poly: the effective learning rate follows a polynomial decay, to be
  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
  //    - sigmoid: the effective learning rate follows a sigmoid decay
  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
  //
  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
  // in the solver parameter protocol buffer, and iter is the current iteration.
  optional string lr_policy = 8; // learning rate policy: one of fixed, step, exp, inv, multistep, poly, sigmoid
  optional float gamma = 9; // The parameter to compute the learning rate.
  optional float power = 10; // The parameter to compute the learning rate.
  optional float momentum = 11; // The momentum value
  optional float weight_decay = 12; // The weight decay.
  // regularization types supported: L1 and L2
  // controlled by weight_decay
  optional string regularization_type = 29 [default = "L2"]; // L1 or L2
  // the stepsize for learning rate policy "step"
  optional int32 stepsize = 13;
  // the stepsize for learning rate policy "multistep"
  repeated int32 stepvalue = 34;
 
  // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
  // whenever their actual L2 norm is larger.
  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval: save results (weights, biases) every `snapshot` iterations
  optional string snapshot_prefix = 15; // The prefix for the snapshot file names
  // whether to snapshot diff in the results or not. Snapshotting diff will help
  // debugging but the final protocol buffer size will be much larger.
  optional bool snapshot_diff = 16 [default = false];
  enum SnapshotFormat {
    HDF5 = 0;
    BINARYPROTO = 1;
  }
  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; // HDF5 or BINARYPROTO
  // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
  enum SolverMode {
    CPU = 0;
    GPU = 1;
  }
  optional SolverMode solver_mode = 17 [default = GPU]; // whether the solver runs in CPU or GPU mode
  // the device_id will that be used in GPU mode. Use device_id = 0 in default.
  optional int32 device_id = 18 [default = 0]; // used in GPU mode
  // If non-negative, the seed with which the Solver will initialize the Caffe
  // random number generator -- useful for reproducible results. Otherwise,
  // (and by default) initialize using a seed derived from the system clock.
  optional int64 random_seed = 20 [default = -1];
 
  // Solver type
  enum SolverType { // solver optimization methods
    SGD = 0;
    NESTEROV = 1;
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  optional SolverType solver_type = 30 [default = SGD]; // selects the solver optimization method
  // numerical stability for RMSProp, AdaGrad, AdaDelta and Adam
  optional float delta = 31 [default = 1e-8];
  // parameters for the Adam solver
  optional float momentum2 = 39 [default = 0.999];

  // RMSProp decay value
  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
  optional float rms_decay = 38;

  // If true, print information about the state of the net that may help with
  // debugging learning problems.
  optional bool debug_info = 23 [default = false];

  // If false, don't save a snapshot after training finishes.
  optional bool snapshot_after_train = 28 [default = true];
}
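To make the lr_policy comments above concrete, here is a small stand-alone helper that mirrors those formulas. It is a hypothetical sketch: in Caffe the actual computation lives in SGDSolver<Dtype>::GetLearningRate in solver.cpp, and the "multistep" policy additionally consults the stepvalue list and current_step_, so it is omitted here.

// Hypothetical stand-alone helper mirroring the lr_policy formulas above;
// the real logic lives in SGDSolver<Dtype>::GetLearningRate (solver.cpp).
#include <cmath>
#include <string>

float EffectiveLearningRate(const std::string& policy, float base_lr,
                            float gamma, float power, int iter,
                            int stepsize, int max_iter) {
  if (policy == "fixed")
    return base_lr;
  if (policy == "step")
    return base_lr * std::pow(gamma, std::floor(float(iter) / stepsize));
  if (policy == "exp")
    return base_lr * std::pow(gamma, iter);
  if (policy == "inv")
    return base_lr * std::pow(1.0f + gamma * iter, -power);
  if (policy == "poly")
    return base_lr * std::pow(1.0f - float(iter) / max_iter, power);
  if (policy == "sigmoid")
    return base_lr * (1.0f / (1.0f + std::exp(-gamma * (iter - stepsize))));
  return base_lr;  // "multistep" omitted: it needs stepvalue and current_step_
}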
Test code for the solver is shown below:
#include "funset.hpp"
#include <string>
#include <vector>
#include<map>
#include "common.hpp"
 
int test_caffe_solver()
{
     caffe::Caffe::set_mode(caffe::Caffe::CPU); // set run caffe mode
 
     const std::string solver_prototxt{ "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_solver.prototxt" };
 
     caffe::SolverParameter solver_param;
     if (!caffe::ReadProtoFromTextFile(solver_prototxt.c_str(), &solver_param)) {
         fprintf(stderr, "parse solver.prototxt fail\n" );
         return - 1 ;
     }
 
     boost::shared_ptr<caffe::solver< float > > solver(caffe::GetSolver< float >(solver_param));
 
    caffe::SolverParameter param = solver->param();

    if (param.has_net())
        fprintf(stderr, "net: %s\n", param.net().c_str());
    if (param.has_net_param()) {
        fprintf(stderr, "has net param\n");
        caffe::NetParameter net_param = param.net_param();
        if (net_param.has_name())
            fprintf(stderr, "net param name: %s\n", net_param.name().c_str());
    }
    if (param.has_train_net())
        fprintf(stderr, "train_net: %s\n", param.train_net().c_str());
    if (param.test_net().size() > 0) {
        for (auto test_net : param.test_net())
            fprintf(stderr, "test_net: %s\n", test_net.c_str());
    }
    if (param.has_train_net_param()) {
        fprintf(stderr, "has train net param\n");
        caffe::NetParameter train_net_param = param.train_net_param();
    }
    if (param.test_net_param().size() > 0) {
        fprintf(stderr, "has test net param\n");
        std::vector<caffe::NetParameter> test_net_param;
        for (auto net_param : param.test_net_param())
            test_net_param.push_back(net_param);
    }
 
    if (param.has_train_state()) {
        fprintf(stderr, "has train state\n");
        caffe::NetState state = param.train_state();
    }
    if (param.test_state().size()) {
        fprintf(stderr, "has test state\n");
    }

    if (param.test_iter_size() > 0) {
        fprintf(stderr, "has test iter\n");
        for (auto iter : param.test_iter())
            fprintf(stderr, "  %d  ", iter);
        fprintf(stderr, "\n");
    }

    if (param.has_test_interval())
        fprintf(stderr, "test interval: %d\n", param.test_interval());
    bool test_compute_loss = param.test_compute_loss();
    fprintf(stderr, "test compute loss: %d\n", test_compute_loss);
    bool test_initialization = param.test_initialization();
    fprintf(stderr, "test initialization: %d\n", test_initialization);
    if (param.has_base_lr()) {
        float base_lr = param.base_lr();
        fprintf(stderr, "base lr: %f\n", base_lr);
    }
    if (param.has_display()) {
        int display = param.display();
        fprintf(stderr, "display: %d\n", display);
    }
    int average_loss = param.average_loss();
    fprintf(stderr, "average loss: %d\n", average_loss);
    if (param.has_max_iter()) {
        int max_iter = param.max_iter();
        fprintf(stderr, "max iter: %d\n", max_iter);
    }
    int iter_size = param.iter_size();
    fprintf(stderr, "iter size: %d\n", iter_size);

    if (param.has_lr_policy())
        fprintf(stderr, "lr policy: %s\n", param.lr_policy().c_str());
    if (param.has_gamma())
        fprintf(stderr, "gamma: %f\n", param.gamma());
    if (param.has_power())
        fprintf(stderr, "power: %f\n", param.power());
    if (param.has_momentum())
        fprintf(stderr, "momentum: %f\n", param.momentum());
    if (param.has_weight_decay())
        fprintf(stderr, "weight decay: %f\n", param.weight_decay());
    std::string regularization_type = param.regularization_type();
    fprintf(stderr, "regularization type: %s\n", param.regularization_type().c_str());
    if (param.has_stepsize())
        fprintf(stderr, "stepsize: %d\n", param.stepsize());
    if (param.stepvalue_size() > 0) {
        fprintf(stderr, "has stepvalue\n");
        for (auto value : param.stepvalue())
            fprintf(stderr, "  %d  ", value);
        fprintf(stderr, "\n");
    }

    fprintf(stderr, "clip gradients: %f\n", param.clip_gradients());

    fprintf(stderr, "snapshot: %d\n", param.snapshot());
    if (param.has_snapshot_prefix())
        fprintf(stderr, "snapshot prefix: %s\n", param.snapshot_prefix().c_str());
    fprintf(stderr, "snapshot diff: %d\n", param.snapshot_diff());
    caffe::SolverParameter_SnapshotFormat snapshot_format = param.snapshot_format();
    fprintf(stderr, "snapshot format: %s\n", snapshot_format == 0 ? "HDF5" : "BINARYPROTO");
    caffe::SolverParameter_SolverMode solver_mode = param.solver_mode();
    fprintf(stderr, "solver mode: %s\n", solver_mode == 0 ? "CPU" : "GPU");
    if (param.has_device_id())
        fprintf(stderr, "device id: %d\n", param.device_id());
    fprintf(stderr, "random seed: %lld\n", static_cast<long long>(param.random_seed()));

    caffe::SolverParameter_SolverType solver_type = param.solver_type();
    std::string solver_method[] { "SGD", "NESTEROV", "ADAGRAD", "RMSPROP", "ADADELTA", "ADAM" };
    fprintf(stderr, "solver type: %s\n", solver_method[solver_type].c_str());
    fprintf(stderr, "delta: %f\n", param.delta());
    fprintf(stderr, "momentum2: %f\n", param.momentum2());

    if (param.has_rms_decay())
        fprintf(stderr, "rms decay: %f\n", param.rms_decay());

    fprintf(stderr, "debug info: %d\n", param.debug_info());
    fprintf(stderr, "snapshot after train: %d\n", param.snapshot_after_train());

    boost::shared_ptr<caffe::Net<float> > net = solver->net();
    std::vector<boost::shared_ptr<caffe::Net<float> > > test_nets = solver->test_nets();
    fprintf(stderr, "test nets size: %d\n", static_cast<int>(test_nets.size()));
    fprintf(stderr, "iter: %d\n", solver->iter());

    return 0;
}
Running this test program prints the values of the parsed solver parameters.
