注册工厂模式解释

典型的使用场景:
当把Onnx 模型解析到自己的框架时,需要遍历onnx模型所有的node,根据它的名字,例如”conv","relu“等关键字,创建相应的自己框架里的算子。

简单工厂模式
如何在自己的源代码里写这样的方法呢?
最原始的是写一堆if,如果node.op_type == “conv”,则创建自己框架的conv, 如果node.op_type == “relu”,则创建自己框架的relu…

把这些逻辑包装到一个工厂类,然后开放出一个create(std::string op_type)接口,就可以自动根据读取到的onnx node optype来创建自己框架相应的算子。这便是简单工厂模式。

但简单工厂模式有个问题,就是在工厂类里要写好key->value的对应关系,一旦出现onnx模型里新增了一个算子(比如”gelu"),用户就不得不修改源代码里的工厂类才能支持该算子。这就不可接受了,因为有时不想把这部分源代码开放出来,更不想要用户这么麻烦去改源代码。

注册工厂模式
此时就可以用到注册工厂模式。用户提供了新算子类的cpp和h文件后,通过一个开放的文件,把约定的名字(例如”gelu“)把这个算子注册到工厂里去,这样源代码在解析node.op_type == "gelu"的时候,就能自动认识这个算子了。

具体的做法是,工厂类里不要保存具体类,这样就避免了新类出现后要改工厂类。但这样一来工厂类该保存什么呢?可以用字典(c++里则是map)保存一个注册类对象,当需要创建具体对象的时候,通过注册类去创建。用字典的原因是,可以根据外界读取的key,自动取出相应的注册类。

接下来需要说明,这个注册类不得不需要一个基类和子类。工厂类里的字典要存放string -> 注册类基类的键值对,这样才能统一适应各种不同的算子。而具体的注册类对象必须得是一个模板,以满足各类算子创建的需求,包括创建用户提供的新类的需求。

注册类里,必须有一个创建函数,用来创建一个具体算子类对象(通常是在堆上创建一个对象,然后返回指针)。这样,当工厂类要创建该对象的时候,就可以通过字典查找到这个注册类,从而调用这个类的创建函数来创建实例。

这样一来,不仅源代码里创建对象的接口完全统一了,被创建的类本身和工厂类也完成了隔离。程序一开始,会进行一个注册动作,把一个文件夹里所有的类都注册到工厂类里的字典,然后到了创建对象的时候,工厂类自动根据读取到的key值,自动创造出不同类的对象。用户要新增加类的时候,只需要提供新类的具体实施,并完成注册(也就是约定它对应的key是什么),和源代码里的工厂类彻底隔离开来。

#include<iostream>
#include<cstring>
#include<map>

class Op {
public:
  virtual void Show() = 0;
  virtual ~Op() {}
};

class Conv : public Op {
public:
  void Show() override {
    std::cout << "我是conv" << std::endl;
  }
};

class BatchNorm : public Op {
public:
  void Show() override {
    std::cout << "我是BatchNorm" << std::endl;
  }
};

// 产品注册接口类
class IProductRegistrar {
public:
  virtual Op *CreateProduct() = 0;

protected:
  // 禁止外部构造和虚构, 子类的"内部"的其他函数可以调用
  IProductRegistrar() {}
  virtual ~IProductRegistrar() {}

private:
  // 禁止外部拷贝和赋值操作
  IProductRegistrar(const IProductRegistrar &);
  const IProductRegistrar &operator=(const IProductRegistrar &);
};

// 工厂类
class ProductFactory {
public:
  static ProductFactory &Instance() {
    static ProductFactory instance;
    return instance;
  }

  void RegisterProduct(IProductRegistrar *registrar, std::string name) {
    m_ProductRegistry[name] = registrar;
  }

  Op *GetProduct(std::string name) {
    if (m_ProductRegistry.find(name) != m_ProductRegistry.end()) {
        return m_ProductRegistry[name]->CreateProduct();
    }

    std::cout << "No op found for " << name << std::endl;
    return NULL;
  }

private:
  ProductFactory() {}
  ~ProductFactory() {}
  ProductFactory(const ProductFactory &);
  const ProductFactory &operator=(const ProductFactory &);

  std::map<std::string, IProductRegistrar*> m_ProductRegistry;
};

// 产品注册类
template <class OpImpl_t>
class ProductRegistrar : public IProductRegistrar {
public:
  explicit ProductRegistrar(std::string name) {
    ProductFactory::Instance().RegisterProduct(this, name);
  }

  Op *CreateProduct() {
    return new OpImpl_t();
  }
};

int main() {
  // ========================== 在一个文件执行所有具体类的注册流程 ===========================//
  ProductRegistrar<Conv> Conv("conv");
  ProductRegistrar<BatchNorm> BatchNorm("batch_norm");
  // ========================== 从工厂获取具体类对象 ========================================//
  Op* pConv = ProductFactory::Instance().GetProduct("conv");
  pConv->Show();
  Op* pBatchNorm = ProductFactory::Instance().GetProduct("batch_norm");
  pBatchNorm->Show();
  // =================================释放资源==============================================//
  if (pConv) {
    delete pConv;
  }

  if (pBatchNorm) {
    delete pBatchNorm;
  }
  return 0;
}

//g++ opfactory.cpp -o hello

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值