任务
上一章中我们是用以下代码来进行对象创建的。这段代码并没有什么问题,每次新增一个层的时候在此处添加els if即可。不过还是可以有更优雅一些的实现。
LayerBase *LayerFactory(string classname)
{
if (classname == "conv")
return new LayerConv();
else if (classname == "pooling")
return new LayerPooling();
else if (classname == "softmax")
return new LayerSoftmax();
else
return nullptr;
}
思路
计划构造一张map映射表,key为类名或者ID,value为可返回对象的函数指针。这样上面的函数就可以变成这一个样子。
// LayerFactory.h
#include<functional>
using std::function;
class LayerFactory
{
public:
static LayerBase *Create(string str){
function <LayerBase*()> fun = s_CreateMap[str];
if (nullptr == fun)
return nullptr;
return fun();
}
static void RegisterCreater(string classname, function<LayerBase*()> creater)
{
s_CreateMap[classname] = creater;
}
private:
static map<string, function<LayerBase*()>> s_CreateMap;
};
// LayerFactory.cpp
#include"LayerFactory.h"
LayerBase *LayerFactoryMap(string classname)
{
return LayerFactory::Create(classname);
}
s_CreateMap
就是这个映射表,最直接的想法是把每个类的构造函数赋予s_CreateMap
,可是C++是无法获取指向类的构造函数的函数指针。
所以转变思路,用另一个函数(比如叫fwrap
)把构造函数封起来,然后再把fwrap
的地址赋予map表即可。如果在每个layer类里面定义这样的fwrap
函数。虽然成员函数可以获得函数指针,但其在调用的时必须要有具体对象(因为参数列表中需要传this指针),这样调用就显得麻烦了。所以最后就只有全局函数或者类的静态函数符合要求。
以Conv层举例,得到以下实现:
class AUTO_FACTORY_Conv
{
public:
AUTO_FACTORY_Conv(){ LayerFactory::RegisterCreater("conv", AUTO_FACTORY_Conv::CreateLayer); };
static LayerBase* CreateLayer() { return new LayerConv(); }
};
可以看到这里把注册语句放到了AUTO_FACTORY_Conv 类的构造函数中。这是为了只要声明一个对象即可自动去注册一个字符串和对象构建器的映射关系。
比如:
static AUTO_FACTORY_Conv g_obj_for_register_conv;
那CreateLayer
函数就是上面所描述的fwrap
了。
可以发现这段代码非常模式化,按照c++惯例用宏来总结下。
#define REGISTER_LAYER_CREATE(idname,classname) \
class AUTO_FACTORY_##idname \
{ \
public: \
AUTO_FACTORY_##idname(){ LayerFactory::RegisterCreater(#idname, AUTO_FACTORY_##idname::CreateLayer); }; \
static LayerBase* CreateLayer() { return new classname(); } \
}; \
static AUTO_FACTORY_##idname class_creater_register_##idname;
最终完整代码如下:
// LayerFactory.h
#include<functional>
using std::function;
LayerBase *LayerFactory(string classname);
class LayerFactory
{
public:
static LayerBase *Create(string str){
function <LayerBase*()> fun = s_CreateMap[str];
if (nullptr == fun)
return nullptr;
return fun();
}
static void RegisterCreater(string classname, function<LayerBase*()> creater)
{
s_CreateMap[classname] = creater;
}
private:
static map<string, function<LayerBase*()>> s_CreateMap;
};
#define REGISTER_LAYER_CREATE(idname,classname) \
class AUTO_FACTORY_##idname \
{ \
public: \
AUTO_FACTORY_##idname(){ LayerFactory::RegisterCreater(#idname, AUTO_FACTORY_##idname::CreateLayer); }; \
static LayerBase* CreateLayer() { return new classname(); } \
}; \
static AUTO_FACTORY_##idname class_creater_register_##idname;
// LayerFactory.cpp
#include"layerFactory.h"
map<string, function<LayerBase*()>> LayerFactory::s_CreateMap;
// register
REGISTER_LAYER_CREATE(conv,LayerConv)
REGISTER_LAYER_CREATE(pooling, LayerPooling)
REGISTER_LAYER_CREATE(softmax, LayerSoftmax)
LayerBase *LayerFactoryMap(string classname)
{
return LayerFactory::Create(classname);
}
小结
- 饶了一圈,代码量似乎没有减小,之前的if else不挺好的吗。考虑到更复杂的需求时才能体现他的好处,比如我们有自动生成代码的需求(根据配置文件,批量生成重复代码)。那么插一句宏会比修改函数体内的逻辑来的更方便。
REGISTER_LAYER_CREATE
宏最终应用在layerFactory.cpp中。即用于注册的类声明只有这个cpp文件可见,且定义的帮助注册的类也是置为static的,这些都是为了把对象构建器的细节封装到这个编译单元内。- 根据工厂模式的定义,属于简单工厂(甚至不属于设计模式,就是应用了封装的思想)。所以设计模式不是用的越多越复杂越好,适合业务的才是最好的。
注意s_CreateMap不能设置成全局变量,因为在不同的编译单元内,全局变量的初始化顺序是未定义的。比如在某个.cpp中向全局的s_CreateMap进行设置,但是s_CreateMap此时不能保证已经被初始化完成了。一个更大的原则是不要定义全局的class对象,如果确实需要则应该把其封装到一个类中。
referecne
1.工厂模式
2.各类c++框架的“注册”机制基本都是此类实现