c++注册类方法实现单例映射

概述

在 Caffe 和阿里巴巴的 MNN 中,广泛使用了"注册"机制(registry pattern)来降低代码的复杂度。下面简单梳理一下它的实现方式。

分析

  1. 要实现一套多种计算架构使用的代码,如cpu、gpu等。
  2. 要在每个计算架构下,实现各种算子。
  3. 使用一套接口实现。
    要达到以上要求,需要为各部分设计通用接口:利用 C++ 类的继承,让所有计算架构(后端)和所有算子都具有相同的接口。接口一致,代码就能大大简化。

解析

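异构平台的后端基类 Backend 与各设备后端 {Device}Backend(如 CPUBackend)分别拥有自己的生成器(Creator),并各自用一个全局单实例的 map 管理这些生成器:Backend 层按 MNNForwardType 管理后端生成器,{Device}Backend 层按 OpType 管理算子生成器。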

定义通用结构

  1. 平台
/// Compute back ends ("forward types") a model can be run on.
/// NOTE: enumerators are not listed in numeric order. Every value is explicit except
/// MNN_FORWARD_ALL, which takes the value following MNN_FORWARD_VULKAN (i.e. 8).
typedef enum {
    MNN_FORWARD_CPU = 0,
    /*
     Firstly find the first available backends not equal to CPU
     If no other backends, use cpu
     */
    MNN_FORWARD_AUTO = 4,
    /*Hand write metal*/
    MNN_FORWARD_METAL = 1,
    /*Use IOS's MPS instead of hand-write metal, Not Support yet*/
    MNN_FORWARD_MPS = 2,
    /*Android / Common Device GPU API*/
    MNN_FORWARD_OPENCL = 3,
    MNN_FORWARD_OPENGL = 6,
    MNN_FORWARD_VULKAN = 7,
    // NOTE(review): presumably a count/sentinel (= 8) rather than a real back end — confirm.
    MNN_FORWARD_ALL
} MNNForwardType;
  1. 算子
// Operator types the back ends can create executions for.
// Only MatMul is shown here; the remaining OpType_* enumerators are elided in this article.
enum OpType
{
    OpType_MatMul = 0,
    // ... further OpType_* enumerators ...
};

接口类

  1. 不可拷贝基类(NonCopyable,禁止派生类被拷贝或移动)
    NonCopyable.h
#ifndef NON_COPY_ABLE_H
#define NON_COPY_ABLE_H
// Mixin base that disables copying and moving.
// Deriving from it (publicly or privately) makes the derived class neither copyable nor movable,
// because the implicitly generated copy/move operations of the derived class would have to use
// the deleted ones below.
class NonCopyable {
public:
    NonCopyable()                              = default;
    NonCopyable(const NonCopyable&)            = delete;
    // Moves use the conventional non-const rvalue reference (a `const T&&` "move" cannot move anything).
    NonCopyable(NonCopyable&&)                 = delete;
    NonCopyable& operator=(const NonCopyable&) = delete;
    NonCopyable& operator=(NonCopyable&&)      = delete;
};

#endif // NON_COPY_ABLE_H

Execution.h

#ifndef Execution_hpp
#define Execution_hpp
#include "NonCopyable.h"
#include <vector>
class Backend;
// Common interface for one executable op instance. It stores a pointer to the Backend
// (the common interface for heterogeneous compute back ends) that it was created for.
class Execution : public NonCopyable
{
public:
    // An Execution always belongs to a back end, so there is no default constructor.
    Execution() = delete;
    // explicit: a raw Backend* must not silently convert into an Execution.
    explicit Execution(Backend *backend) : mBackEnd(backend)
    {
        // nothing to do
    }
    virtual ~Execution() = default;

    // Run the op on `inputs` / `outputs` and return a status code
    // (NOTE(review): the only implementation shown, CPUMatMul, returns 0 — confirm 0 means success).
    virtual int onExecute(const std::vector<float> &inputs, const std::vector<float> &outputs) = 0;

    // The back end this execution was constructed with.
    Backend *backend() const
    {
        return mBackEnd;
    }

public:
    // Factory interface that builds an Execution bound to a given back end.
    class Creator : public NonCopyable
    {
    public:
        virtual ~Creator() = default;
        virtual Execution *onCreate(Backend *backend) const = 0;
    };

private:
    // Non-owning: the destructor does not release the back end.
    Backend *mBackEnd;
};

#endif /* Execution_hpp */
  1. 后端接口类
    Backend.h
#ifndef Backend_hpp
#define Backend_hpp
#include "MNNForwardType.h"
#include "NonCopyable.h"
#include "Execution.h"
#include "MNN_generated.h"
#include <vector>
// Abstract base class shared by every heterogeneous compute back end (CPU, GPU, ...).
// A concrete back end implements onCreate() to build the Execution for a given op.
class Backend : NonCopyable
{
public:
    // Configuration handed to a BackendCreator when a back end is instantiated.
    struct Info
    {
        /** forward type. */
        MNNForwardType type = MNN_FORWARD_CPU;
        /** for CPU only. number of threads. */
        int numThread = 4;
        enum Mode
        {
            // The Op will be run in execution->onExecute
            DIRECT = 0,
            // The Op will be recorded. Run in onExecuteBegin and Wait in onExecuteEnd
            INDIRECT = 1
        };
        Mode mode = DIRECT;
    };
public:
    // explicit: a bare MNNForwardType must not implicitly convert into a Backend.
    explicit Backend(MNNForwardType type) : mType(type) {}
    virtual ~Backend() = default;
public:
    // Build the Execution for op `type`. Implementations may return nullptr when the op is
    // unsupported. NOTE(review): ownership of the returned pointer is not expressed by the
    // interface; presumably the caller owns it — confirm.
    virtual Execution *onCreate(const std::vector<float> &inputs, const std::vector<float> &outputs, OpType type) = 0;
public:
    // Forward type this back end was constructed for.
    inline MNNForwardType type() const {
        return mType;
    }
private:
    const MNNForwardType mType;
};
// Abstract factory for back ends. Concrete creators are registered per MNNForwardType
// through MNNInsertExtraBackendCreator.
class BackendCreator
{
public:
    virtual ~BackendCreator() = default;

    // Build a new back end described by `info`; may return nullptr on failure.
    virtual Backend *onCreate(const Backend::Info &info) const = 0;

    // Validate / adjust `info` before creation. The default forces DIRECT mode and accepts.
    virtual bool onValid(Backend::Info &info) const
    {
        info.mode = Backend::Info::DIRECT;
        return true;
    }
protected:
    BackendCreator() = default;
};
// Return the creator registered for `type`, or nullptr if none is registered (or, for a
// creator registered with needCheck == true, if a probe back end cannot be created).
const BackendCreator* MNNGetExtraBackendCreator(MNNForwardType type);

// Register `creator` for `type`. Returns false if `type` already has a creator.
// needCheck: when true, lookups first verify that the creator can actually build a back end.
bool MNNInsertExtraBackendCreator(MNNForwardType type, const BackendCreator* creator,
                                  bool needCheck = false);
#endif /* Backend_hpp */

Backend.cpp

#include "Backend.h"
#include <map>
#include <mutex>
#include <memory>
extern void registerCPUBackendCreator();
// Register the built-in back ends (currently only the CPU one).
// Safe to call from every lookup: std::call_once runs the registration a single time.
void registerBackend()
{
    static std::once_flag s_flag;
    std::call_once(s_flag, registerCPUBackendCreator);
}
// Process-wide registry of back-end creators, keyed by forward type. Each entry holds the
// creator and its needCheck flag. The table is created on first use (function-local static
// initialisation is thread-safe since C++11) and is never freed by this code.
static std::map<MNNForwardType, std::pair<const BackendCreator *, bool>> &GetExtraCreator()
{
    using CreatorTable = std::map<MNNForwardType, std::pair<const BackendCreator *, bool>>;
    static CreatorTable *const sTable = new CreatorTable();
    return *sTable;
}
// Look up the back-end creator registered for `type`.
// Returns nullptr when nothing (or a null creator) is registered. If the creator was
// registered with needCheck == true, a throw-away back end is built first, and the creator
// is returned only if that construction yields a non-null Backend.
const BackendCreator *MNNGetExtraBackendCreator(MNNForwardType type)
{
    // Make sure the built-in CPU back end is registered before the first lookup.
    registerBackend();

    auto &gExtraCreator = GetExtraCreator();
    auto iter = gExtraCreator.find(type);
    if (iter == gExtraCreator.end())
    {
        return nullptr;
    }
    const BackendCreator *creator = iter->second.first;
    const bool needCheck = iter->second.second;
    if (nullptr == creator)
    {
        // Guard: probing a null creator would dereference it.
        return nullptr;
    }
    if (!needCheck)
    {
        return creator;
    }
    // Probe the creator. unique_ptr (no shared control block needed) destroys the
    // temporary back end on scope exit; Backend has a virtual destructor.
    Backend::Info info;
    info.type = type;
    std::unique_ptr<Backend> probe(creator->onCreate(info));
    return (nullptr != probe) ? creator : nullptr;
}
// Register `creator` for forward type `type`.
// Returns false (leaving the registry unchanged) if `creator` is null or if `type` already has a
// creator; returns true once the creator is stored. needCheck is stored alongside the creator and
// makes MNNGetExtraBackendCreator probe it before handing it out.
bool MNNInsertExtraBackendCreator(MNNForwardType type, const BackendCreator* creator, bool needCheck) {
    // A stored null creator would block later valid registrations for `type` and be
    // dereferenced by MNNGetExtraBackendCreator when needCheck is true.
    if (nullptr == creator) {
        return false;
    }
    auto& gExtraCreator = GetExtraCreator();
    // map::insert leaves an existing entry untouched and reports via .second whether it
    // inserted, so no separate find() is needed.
    return gExtraCreator.insert(std::make_pair(type, std::make_pair(creator, needCheck))).second;
}

计算架构

下面以cpu为例
CPUBackend.h

#ifndef CPUBackend_hpp
#define CPUBackend_hpp
#include "Backend.h"
// CPU back end: implements the Backend interface and builds op executions through a table of
// per-OpType Creator objects, which REGISTER_CPU_OP_CREATOR fills in.
class CPUBackend final : public Backend
{
public:
    // explicit: an int thread count must not implicitly convert into a CPUBackend.
    explicit CPUBackend(int numberThread = 4);
    virtual ~CPUBackend();

    // Build the Execution for op `type` via the registered creator; nullptr if unsupported.
    virtual Execution *onCreate(const std::vector<float> &inputs, const std::vector<float> &outputs, OpType type) override;

public:
    // Abstract factory for one CPU op implementation.
    class Creator
    {
    public:
        // Virtual so concrete creators can be destroyed safely through a Creator pointer.
        virtual ~Creator() = default;
        virtual Execution *onCreate(const std::vector<float> &inputs, const std::vector<float> &outputs,
                                    OpType type, Backend *backend) const = 0;
    };
    // Register creator `c` for op type `t`. Returns false (table unchanged) if `t`
    // already has a creator.
    static bool addCreator(OpType t, Creator *c);

private:
    int mThreadNumber;
};
// Registrar for op creators: constructing one registers a heap-allocated T
// (a CPUBackend::Creator subclass) for `type`.
template <class T>
class CPUCreatorRegister {
public:
    explicit CPUCreatorRegister(OpType type) {
        T *creator = new T;
        if (!CPUBackend::addCreator(type, creator)) {
            // Not stored (duplicate registration), so nobody else owns it: free it.
            delete creator;
        }
    }
};
// Defines a file-static CPUCreatorRegister object, whose constructor performs the
// registration during static initialisation of the including translation unit.
#define REGISTER_CPU_OP_CREATOR(name, opType) static CPUCreatorRegister<name> _Create##opType(opType)

#endif /* CPUBackend_hpp */

CPUBackend.cpp

#include "CPUBackend.h"
#include <stdio.h>
#include <algorithm>
#include <map>
#include <mutex>
// 创建map,用于管理ops单实例
static inline std::map<OpType, CPUBackend::Creator *> *getCreatorMap()
{
    static std::once_flag of;
    static std::map<OpType, CPUBackend::Creator *> *ret = nullptr;
    std::call_once(of, [&]() { ret = new std::map<OpType, CPUBackend::Creator *>; });
    return ret;
}
// Register creator `c` for op type `t` in the CPU creator table.
// Returns false without modifying the table when `c` is null or `t` is already registered
// (the latter also prints an error); returns true once the creator is stored.
// The table keeps the raw pointer; `c` is not freed by this code.
bool CPUBackend::addCreator(OpType t, Creator *c)
{
    // A stored null creator would be dereferenced later by CPUBackend::onCreate.
    if (nullptr == c)
    {
        return false;
    }
    auto map = getCreatorMap();
    // insert() does not overwrite an existing entry and reports via .second whether it
    // inserted, replacing the separate find() + insert() pair with a single lookup.
    if (!map->insert(std::make_pair(t, c)).second)
    {
        // %d expects an int; OpType's underlying type is not visible here, so cast explicitly.
        printf("Error: %d type has be added\n", static_cast<int>(t));
        return false;
    }
    return true;
}

// Construct a CPU back end. The requested thread count is clamped to at least 1.
CPUBackend::CPUBackend(int numberThread)
    : Backend(MNN_FORWARD_CPU), mThreadNumber(std::max(1, numberThread))
{
}

// Nothing to release: the back end owns no resources beyond its members.
CPUBackend::~CPUBackend() = default;
// Build the Execution for op `type` using the creator registered in the CPU op table.
// Prints a diagnostic and returns nullptr when no creator is registered for `type`, or
// when the registered creator declines to build an execution.
Execution *CPUBackend::onCreate(const std::vector<float> &inputs, const std::vector<float> &outputs, OpType type)
{
    const auto *table = getCreatorMap();
    const auto found = table->find(type);
    if (found == table->end())
    {
        printf("Don't support type\n");
        return nullptr;
    }

    Execution *execution = found->second->onCreate(inputs, outputs, type, this);
    if (execution != nullptr)
    {
        return execution;
    }
    printf("The Creator Don't support type\n");
    return nullptr;
}
// Back-end factory producing CPUBackend instances; Info::numThread is the thread count.
class CPUBackendCreator : public BackendCreator
{
public:
    Backend *onCreate(const Backend::Info &info) const override
    {
        Backend *cpuBackend = new CPUBackend(info.numThread);
        return cpuBackend;
    }
};
// 注册CPU后端生成器到后端的map生成器管理器中
void registerCPUBackendCreator()
{
    MNNInsertExtraBackendCreator(MNN_FORWARD_CPU, new CPUBackendCreator);
};

算子

CPUMatMul.h


#ifndef CPUMatMul_hpp
#define CPUMatMul_hpp
#include "Execution.h"
// CPU execution of the MatMul op.
// The guard is named after this file: the former CPUBatchMatMul_hpp guard could collide with
// a real CPUBatchMatMul header and silently drop one of the two.
class CPUMatMul : public Execution
{
public:
    // explicit: a raw Backend* must not implicitly convert into a CPUMatMul.
    explicit CPUMatMul(Backend *backend);
    virtual ~CPUMatMul() = default;

    virtual int onExecute(const std::vector<float> &inputs, const std::vector<float> &outputs) override;
};

#endif /* CPUMatMul_hpp */

CPUMatMul.cpp

#include "CPUMatMul.h"
#include "CPUBackend.h"

CPUMatMul::CPUMatMul(Backend *backend) : Execution(backend)
{
}
// Run the MatMul op. Not implemented yet: the arguments are ignored and 0 is returned.
int CPUMatMul::onExecute(const std::vector<float> &inputs, const std::vector<float> &outputs)
{
    // TODO: implement the matrix multiplication.
    (void)inputs;
    (void)outputs;
    const int status = 0;
    return status;
}
// Creator for the MatMul op: always builds a CPUMatMul bound to `backend`,
// ignoring the tensors and the op type it receives.
class CPUMatMulCreator : public CPUBackend::Creator
{
public:
    Execution *onCreate(const std::vector<float> &inputs, const std::vector<float> &outputs,
                        OpType type, Backend *backend) const override
    {
        (void)inputs;
        (void)outputs;
        (void)type;
        Execution *matmul = new CPUMatMul(backend);
        return matmul;
    }
};
// Register CPUMatMulCreator as the CPU creator for OpType_MatMul. The macro defines a
// file-static CPUCreatorRegister object whose constructor adds the creator to CPUBackend's table.
REGISTER_CPU_OP_CREATOR(CPUMatMulCreator, OpType_MatMul);