典型的使用场景:
当把Onnx 模型解析到自己的框架时,需要遍历onnx模型所有的node,根据它的名字,例如”conv","relu“等关键字,创建相应的自己框架里的算子。
简单工厂模式
如何在自己的源代码里写这样的方法呢?
最原始的是写一堆if,如果node.op_type == “conv”,则创建自己框架的conv, 如果node.op_type == “relu”,则创建自己框架的relu…
把这些逻辑包装到一个工厂类,然后开放出一个create(std::string op_type)接口,就可以自动根据读取到的onnx node optype来创建自己框架相应的算子。这便是简单工厂模式。
但简单工厂模式有个问题,就是在工厂类里要写好key->value的对应关系,一旦出现onnx模型里新增了一个算子(比如”gelu"),用户就不得不修改源代码里的工厂类才能支持该算子。这就不可接受了,因为有时不想把这部分源代码开放出来,更不想要用户这么麻烦去改源代码。
注册工厂模式
此时就可以用到注册工厂模式。用户提供了新算子类的cpp和h文件后,通过一个开放的文件,把约定的名字(例如”gelu“)把这个算子注册到工厂里去,这样源代码在解析node.op_type == "gelu"的时候,就能自动认识这个算子了。
具体的做法是,工厂类里不要保存具体类,这样就避免了新类出现后要改工厂类。但这样一来工厂类该保存什么呢?可以用字典(c++里则是map)保存一个注册类对象,当需要创建具体对象的时候,通过注册类去创建。用字典的原因是,可以根据外界读取的key,自动取出相应的注册类。
接下来需要说明,这个注册类不得不需要一个基类和子类。工厂类里的字典要存放string -> 注册类基类的键值对,这样才能统一适应各种不同的算子。而具体的注册类对象必须得是一个模板,以满足各类算子创建的需求,包括创建用户提供的新类的需求。
注册类里,必须有一个创建函数,用来创建一个具体算子类对象(通常是在堆上创建一个对象,然后返回指针)。这样,当工厂类要创建该对象的时候,就可以通过字典查找到这个注册类,从而调用这个类的创建函数来创建实例。
这样一来,不仅源代码里创建对象的接口完全统一了,被创建的类本身和工厂类也完成了隔离。程序一开始,会进行一个注册动作,把一个文件夹里所有的类都注册到工厂类里的字典,然后到了创建对象的时候,工厂类自动根据读取到的key值,自动创造出不同类的对象。用户要新增加类的时候,只需要提供新类的具体实施,并完成注册(也就是约定它对应的key是什么),和源代码里的工厂类彻底隔离开来。
#include<iostream>
#include<cstring>
#include<map>
class Op {
public:
virtual void Show() = 0;
virtual ~Op() {}
};
class Conv : public Op {
public:
void Show() override {
std::cout << "我是conv" << std::endl;
}
};
class BatchNorm : public Op {
public:
void Show() override {
std::cout << "我是BatchNorm" << std::endl;
}
};
// 产品注册接口类
class IProductRegistrar {
public:
virtual Op *CreateProduct() = 0;
protected:
// 禁止外部构造和虚构, 子类的"内部"的其他函数可以调用
IProductRegistrar() {}
virtual ~IProductRegistrar() {}
private:
// 禁止外部拷贝和赋值操作
IProductRegistrar(const IProductRegistrar &);
const IProductRegistrar &operator=(const IProductRegistrar &);
};
// 工厂类
class ProductFactory {
public:
static ProductFactory &Instance() {
static ProductFactory instance;
return instance;
}
void RegisterProduct(IProductRegistrar *registrar, std::string name) {
m_ProductRegistry[name] = registrar;
}
Op *GetProduct(std::string name) {
if (m_ProductRegistry.find(name) != m_ProductRegistry.end()) {
return m_ProductRegistry[name]->CreateProduct();
}
std::cout << "No op found for " << name << std::endl;
return NULL;
}
private:
ProductFactory() {}
~ProductFactory() {}
ProductFactory(const ProductFactory &);
const ProductFactory &operator=(const ProductFactory &);
std::map<std::string, IProductRegistrar*> m_ProductRegistry;
};
// 产品注册类
template <class OpImpl_t>
class ProductRegistrar : public IProductRegistrar {
public:
explicit ProductRegistrar(std::string name) {
ProductFactory::Instance().RegisterProduct(this, name);
}
Op *CreateProduct() {
return new OpImpl_t();
}
};
int main() {
// ========================== 在一个文件执行所有具体类的注册流程 ===========================//
ProductRegistrar<Conv> Conv("conv");
ProductRegistrar<BatchNorm> BatchNorm("batch_norm");
// ========================== 从工厂获取具体类对象 ========================================//
Op* pConv = ProductFactory::Instance().GetProduct("conv");
pConv->Show();
Op* pBatchNorm = ProductFactory::Instance().GetProduct("batch_norm");
pBatchNorm->Show();
// =================================释放资源==============================================//
if (pConv) {
delete pConv;
}
if (pBatchNorm) {
delete pBatchNorm;
}
return 0;
}
//g++ opfactory.cpp -o hello