Caffe中Solver解析

1.Solver的初始化

shared_ptr<caffe::Solver<float>>
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。
具体步骤:
(1)SolverRegistry::CreateSolver(solver_param)
(2)通过static的g_registry_[type]获取type对应的Solver的Creator函数指针。
(3)调用Creator函数。
(4)new SGDSolver<Dtype>(solver_param)创建solver。
SolverRegistry类源码:

class SolverRegistry 
{
   public:
   typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
   typedef std::map<string, Creator> CreatorRegistry;
   static CreatorRegistry& Registry() 
   {
     static CreatorRegistry* g_registry_ = new CreatorRegistry();
     return *g_registry_;
   }
   static void AddCreator(const string& type, Creator creator) 
   {
     CreatorRegistry& registry = Registry();
     CHECK_EQ(registry.count(type), 0)
         << "Solver type " << type << " already registered.";
     registry[type] = creator;
   }
   static Solver<Dtype>* CreateSolver(const SolverParameter& param) 
   {
     const string& type = param.type();
     CreatorRegistry& registry = Registry();
     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
         << " (known types: " << SolverTypeListString() << ")";
     return registry[type](param);
   }
   static vector<string> SolverTypeList() 
   {
     CreatorRegistry& registry = Registry();
     vector<string> solver_types;
     for (typename CreatorRegistry::iterator iter = registry.begin();iter != registry.end(); ++iter) 
     {
       solver_types.push_back(iter->first);
     }
     return solver_types;
   }
  private:
   SolverRegistry() {}
   static string SolverTypeListString() 
   {
     vector<string> solver_types = SolverTypeList();
     string solver_types_str;
     for (vector<string>::iterator iter = solver_types.begin();iter != solver_types.end(); ++iter) 
     {
       if (iter != solver_types.begin()) 
       {
         solver_types_str += ", ";
       }
       solver_types_str += *iter;
     }
     return solver_types_str;
   }
};

SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回。
Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。
Register的具体步骤:
(1)Registry_Solver_Class(SGD)。
(2)定义Creator函数,Registry_Solver_Creator。
(3)定义SolverRegistry<float>类型的static变量,定义SolverRegistry<double>类型的static变量。
(4)SolverRegistry::AddCreator将定义的Creator函数指针添加到static的变量g_registry_(map)中。

SolverRegisterer源码:

template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
                   Solver<Dtype>* (*creator)(const SolverParameter&));
};

#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

}
#endif 

在sgd_solver.cpp文件末尾有REGISTER_SOLVER_CL

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值