1.Solver的初始化
shared_ptr<caffe::Solver<float>>
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。
具体步骤:
(1)SolverRegistry::CreateSolver(solver_param)。
(2)通过static的g_registry_[type]获取type对应的Solver的Creator函数指针。
(3)调用Creator函数。
(4)new SGDSolver<Dtype>(solver_param)创建solver。
SolverRegistry类源码:
class SolverRegistry
{
public:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
typedef std::map<string, Creator> CreatorRegistry;
static CreatorRegistry& Registry()
{
static CreatorRegistry* g_registry_ = new CreatorRegistry();
return *g_registry_;
}
static void AddCreator(const string& type, Creator creator)
{
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;
}
static Solver<Dtype>* CreateSolver(const SolverParameter& param)
{
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
static vector<string> SolverTypeList()
{
CreatorRegistry& registry = Registry();
vector<string> solver_types;
for (typename CreatorRegistry::iterator iter = registry.begin();iter != registry.end(); ++iter)
{
solver_types.push_back(iter->first);
}
return solver_types;
}
private:
SolverRegistry() {}
static string SolverTypeListString()
{
vector<string> solver_types = SolverTypeList();
string solver_types_str;
for (vector<string>::iterator iter = solver_types.begin();iter != solver_types.end(); ++iter)
{
if (iter != solver_types.begin())
{
solver_types_str += ", ";
}
solver_types_str += *iter;
}
return solver_types_str;
}
};
SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回。
Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。
Register的具体步骤:
(1)Registry_Solver_Class(SGD)。
(2)定义Creator函数,Registry_Solver_Creator。
(3)定义SolverRegistry<float>类型的static变量,定义SolverRegistry<double>类型的static变量。
(4)SolverRegistry::AddCreator将定义的Creator函数指针添加到static的变量g_registry_(map)中。
SolverRegisterer源码:
template <typename Dtype>
class SolverRegisterer {
public:
SolverRegisterer(const string& type,
Solver<Dtype>* (*creator)(const SolverParameter&));
};
#define REGISTER_SOLVER_CREATOR(type, creator) \
static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
#define REGISTER_SOLVER_CLASS(type) \
template <typename Dtype> \
Solver<Dtype>* Creator_##type##Solver( \
const SolverParameter& param) \
{ \
return new type##Solver<Dtype>(param); \
} \
REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
}
#endif
在sgd_solver.cpp文件末尾有REGISTER_SOLVER_CL

最低0.47元/天 解锁文章
2238

被折叠的 条评论
为什么被折叠?



