SolverAction
最开始定义了一个枚举
/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a
* client may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*/
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}
枚举了一些动作,有时候我们的客户端可能会提前终止程序,例如我们使用 ctrl+c
把程序终止了,他就需要做出一些响应,例如生成快照,快速保存等等。那么就需要使用这几个状态来判断
- None就是没有什么异常
- STOP就是训练停止了
- SNAPSHOT就是创建一个SNAPSHOT然后我们接着去训练
ActionCallback 类型
这里定义了一个新的类型,给需要用到那几个状态的回调函数使用
/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
typedef boost::function<SolverAction::Enum()> ActionCallback;
用来在出了任何异常的时候调用相应状态的应急措施,看它的参数就是上面定义的那个枚举的状态码。
Solver类
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver
注释上写了这个类是个优化网络的一个借口,需要自己去实现 ApplyUpdate
这个函数来实现具体的参数更新的过程。
Solver 构造函数
explicit Solver(const SolverParameter& param);
explicit Solver(const string& param_file);
他也是两个构造函数,一个是从参数读取,一个是从文件读取。
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
: net_(), callbacks_(), requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m);
Init(param);
}
可以看到,这个函数就是从文件里去读,然后放在 SolverParameter
类型的变量里,最后调用了 Init
这个函数。
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
: net_(), callbacks_(), requested_early_exit_(false) {
Init(param);
}
如果直接传的就是这个类型的对象的话那么他就直接调用 Init
这个函数了,所以其实这两个函数实现的功能是一样的。
接下来我们重点来看这个 Init
函数。
Init 函数
void Init(const SolverParameter& param);
这个函数根据读取的参数来初始化网络
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
CheckSnapshotWritePermissions();
if (param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed() + Ca