Caffe 的 solver 工厂（solver_factory.hpp 中的 SolverRegistry 类）根据 solver.prototxt 中定义的 type 字段创建相应类型的 solver 实例。
创建代码如下:
caffe.cpp
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
其实现位于 solver_factory.hpp 中：
/**
* @brief A solver factory that allows one to register solvers, similar to
* layer factory. During runtime, registered solvers could be called by passing
* a SolverParameter protobuffer to the CreateSolver function:
*
* SolverRegistry<Dtype>::CreateSolver(param);
*
* There are two ways to register a solver. Assuming that we have a solver like:
*
* template <typename Dtype>
* class MyAwesomeSolver : public Solver<Dtype> {
* // your implementations
* };
*
* and its type is its C++ class name, but without the "Solver" at the end
* ("MyAwesomeSolver" -> "MyAwesome").
*
* If the solver is going to be created simply by its constructor, in your C++
* file, add the following line:
*
* REGISTER_SOLVER_CLASS(MyAwesome);
*
* Or, if the solver is going to be created by another creator function, in the
* format of:
*
* template <typename Dtype>
* Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
* // your implementation
* }
*
* then you can register the creator function instead, like
*
* REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
*
* Note that each solver type should only be registered once.
*/
#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
template <typename Dtype>
class Solver;
template <typename Dtype>
class SolverRegistry {
public:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);// creator是个函数指针,此函数返回solver实例,creator指向不同类型的solver子类会创建不同类型的solver对象。
typedef std::map<string, Creator> CreatorRegistry;// 注册器,key是solver的type,value是函数指针creator
static CreatorRegistry& Registry(){ // Registry()声明为静态成员函数,这样创建的g_registry_变量代替了声明为全局或成员变量,而只需通过调用Registry()即可访问。
static CreatorRegistry* g_registry_ = new CreatorRegistry();// 只创建一次,之后存储在静态区。
return *g_registry_;
}
// Adds a creator.
static void AddCreator(const string& type, Creator creator) {// 为g_registry_注册一个creator,creator用于实例化solver对象。
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;
}
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);// 根据type找到注册器g_registry_中的
}
static vector<string> SolverTypeList() {
CreatorRegistry& registry = Registry();
vector<string> solver_types;
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter) {
solver_types.push_back(iter->first);
}
return solver_types;
}
private:
// Solver registry should never be instantiated - everything is done with its
// static variables.
SolverRegistry() {}
static string SolverTypeListString() {
vector<string> solver_types = SolverTypeList();
string solver_types_str;
for (vector<string>::iterator iter = solver_types.begin();
iter != solver_types.end(); ++iter) {
if (iter != solver_types.begin()) {
solver_types_str += ", ";
}
solver_types_str += *iter;
}
return solver_types_str;
}
};
template <typename Dtype>
class SolverRegisterer {
public:
SolverRegisterer(const string& type,
Solver<Dtype>* (*creator)(const SolverParameter&)) {
// LOG(INFO) << "Registering solver type: " << type;
SolverRegistry<Dtype>::AddCreator(type, creator);
}
};
// Registers `creator<float>` and `creator<double>` under the name #type (the
// stringized macro argument). It does this by defining two static
// SolverRegisterer objects, whose constructors call
// SolverRegistry<Dtype>::AddCreator.
//
// The expansion intentionally omits the final ';' so the use site supplies
// it, e.g.  REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
//
// SolverRegisterer, Solver and SolverParameter are used unqualified, so expand
// these macros inside namespace caffe.
//
// NOTE: the last line of this macro must not end with a backslash. If it
// does, the following #define is spliced into this macro's body, which is a
// compile error and leaves REGISTER_SOLVER_CLASS undefined.
#define REGISTER_SOLVER_CREATOR(type, creator)                               \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);  \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)

// Convenience form for solvers that are built directly by their constructor.
// It defines a creator function template Creator_<type>Solver<Dtype> that
// returns `new <type>Solver<Dtype>(param)`, then registers it for float and
// double through REGISTER_SOLVER_CREATOR.
// E.g. REGISTER_SOLVER_CLASS(MyAwesome); requires a class template
// MyAwesomeSolver<Dtype> constructible from a const SolverParameter&.
#define REGISTER_SOLVER_CLASS(type)                                          \
  template <typename Dtype>                                                  \
  Solver<Dtype>* Creator_##type##Solver(                                     \
      const SolverParameter& param)                                          \
  {                                                                          \
    return new type##Solver<Dtype>(param);                                   \
  }                                                                          \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
} // namespace caffe
#endif // CAFFE_SOLVER_FACTORY_H_