ONNX 之 Elemwise 类算子的执行（以 CLIP 为例）

CLIP 算子的广播类型 broad_cast_type 为 BcastType::UNKNOWN_BCAST_TYPE，所以会调用到 fallback::ElemwiseImpl::exec(srcs, dst)；又因为 srcs.size() > 2（CLIP 有三个输入：x、下界 y 和上界 z），最终调用到 naive::ElemwiseForwardImpl::exec(srcs, dst)。

先看下调用堆栈图:

         

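Elemwise 类首先定义了按 mode 分发的执行器（dispatcher）ModeDispatcher。它通过宏 on_arity_dispatched_cb_dtype 被调用：该宏根据输出张量的 dtype 选出对应的 ModeDispatcher<arity, category, ctype>::run 并执行。而这个宏又是在 ElemwiseForwardImpl::on_arity_dispatched 和 ElemwiseForwardImpl::on_arity_dispatched_no_bool 两个方法中展开调用的。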

// Per-dtype step of the dtype dispatch in ElemwiseForwardImpl::on_arity_dispatched*().
// If the destination tensor's dtype equals `_dt`, take that dtype's ctype and
// category from DTypeTrait<_dt> and return the result of the matching
// ModeDispatcher<arity, category, ctype>::run(). Otherwise nothing happens and
// control falls through, so several expansions can be chained to test one
// dtype after another.
// Requires `arity`, `src`, `m_dst`, `m_param` and `handle()` to be in scope at
// the expansion site.
#define on_arity_dispatched_cb_dtype(_dt)                                       \
    if (m_dst->layout.dtype == _dt()) {                                         \
        using dtrait = DTypeTrait<_dt>;                                         \
        using ctype = dtrait::ctype;                                            \
        return ModeDispatcher<arity, dtrait::category, ctype>::run(             \
                static_cast<HandleImpl*>(handle()), src, m_param.mode, *m_dst); \
    }



/*!
 * \brief Dispatch an elemwise computation with `arity` inputs on the dtype of
 *      the destination tensor.
 *
 * `src` is the ElemwiseOpParamN<arity> that is handed to ModeDispatcher::run.
 * Each on_arity_dispatched_cb_dtype() expansion returns from this function
 * once it matches m_dst's dtype. It tries the dtypes listed by
 * MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT, then those of
 * MEGDNN_FOREACH_COMPUTING_DTYPE_INT, then dtype::Bool. Reaching the end means
 * no dtype matched, and "bad dtype" is thrown.
 */
template <int arity>
void ElemwiseForwardImpl::on_arity_dispatched() {
    auto src = make_elemwise_op_param<arity>();
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
    MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
    on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool)
    megdnn_throw("bad dtype");
}

/*!
 * \brief Same dtype dispatch as on_arity_dispatched(), but without the
 *      dtype::Bool case.
 *
 * Only the dtypes listed by MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT and
 * MEGDNN_FOREACH_COMPUTING_DTYPE_INT are tried. A Bool (or any other
 * unlisted) destination dtype ends in the "bad dtype" throw.
 */
template <int arity>
void ElemwiseForwardImpl::on_arity_dispatched_no_bool() {
    auto src = make_elemwise_op_param<arity>();
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype)
    MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype)
    megdnn_throw("bad dtype");
}

以下为 ModeDispatcher 特化版本的定义（由宏 IMPL_MODE_DISPATCHER 生成）：

#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT
IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT);
#define IMPL_MODE_DISPATCHER(_arity, _dtype_cat)                                   \
    template <typename ctype>                                                      \
    struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> {        \
        static constexpr int arity = _arity;                                       \
        static void run(                                                           \
                HandleImpl* handle, const ElemwiseOpParamN<arity>& src, Mode mode, \
                const TensorND dst) {                                              \
            switch (mode) {                                                        \
                FOREACH(_cb_dispatch_mode)                                         \
                default:                                                           \
                    megdnn_throw("bad mode");                                      \
            }                                                                      \
        }                                                                          \
    }

#undef FOREACH

//! MEGDNN_ELEMWISE_MODE_ENABLE(mode, cb) applies the callback macro `cb` to
//! `mode`. Every entry of the mode list below is emitted through it.
#define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode)

//! X-macro over the elemwise modes used as the FOREACH list for
//! IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT). It expands to `_cb(MODE)`
//! once per mode, in this order: COND_LEQ_MOV, COND_LT_MOV, FUSE_MUL_ADD3,
//! CLIP, PRELU_GRAD.
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(_cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, _cb)      \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, _cb)       \
    MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, _cb)     \
    MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, _cb)              \
    MEGDNN_ELEMWISE_MODE_ENABLE(PRELU_GRAD, _cb)

将宏 IMPL_MODE_DISPATCHER 展开后：



// First expansion step of IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT):
// _arity -> 3 and _dtype_cat -> DTypeCategory::FLOAT are substituted, while
// FOREACH(_cb_dispatch_mode) is not yet expanded. The trailing `;` comes from
// the `;` written after the macro invocation.
template <typename ctype>
struct ElemwiseForwardImpl::ModeDispatcher<3, DTypeCategory::FLOAT, ctype> {
    static constexpr int arity = 3;
    static void run(
            HandleImpl* handle, const ElemwiseOpParamN<arity>& src, Mode mode,
            const TensorND dst) {
        switch (mode) {
            FOREACH(_cb_dispatch_mode)
            default:
                megdnn_throw("bad mode");
        }
    }
};

再将 FOREACH(_cb_dispatch_mode) 展开，得到：

// Second expansion step: FOREACH, #defined as
// MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT for this instantiation, is
// expanded as well. Each listed mode becomes one _cb_dispatch_mode(...)
// entry, and each entry in turn expands to a `case Mode::<mode>:` label.
template <typename ctype>
struct ElemwiseForwardImpl::ModeDispatcher<3, DTypeCategory::FLOAT, ctype> {
    static constexpr int arity = 3;
    static void run(
            HandleImpl* handle, const ElemwiseOpParamN<arity>& src, Mode mode,
            const TensorND dst) {
        switch (mode) {
            _cb_dispatch_mode(COND_LEQ_MOV)
            _cb_dispatch_mode(COND_LT_MOV)
            _cb_dispatch_mode(FUSE_MUL_ADD3)
            _cb_dispatch_mode(CLIP)
            _cb_dispatch_mode(PRELU_GRAD)
            default:
                megdnn_throw("bad mode");
        }
    }
};

// Emits one `case Mode::_m:` label for the switch in ModeDispatcher::run().
// The case binds KernImpl to the ElemwiseKern specialization for
// megcorePlatformCPU, mode `_m` and element type `ctype`. It then hands
// ElemArithKernCaller<arity, KernImpl>::run(dst.ptr<ctype>(), params) to
// MEGDNN_DISPATCH_CPU_KERN and returns from run().
// It relies on `ctype`, `arity`, `handle`, `src` and `dst` being in scope at
// the expansion site.
// MEGDNN_COMMA presumably expands to `,`. It keeps the template-argument comma
// from splitting MEGDNN_DISPATCH_CPU_KERN's second macro argument.
// NOTE(review): there is no `break` after the do/while. If the MIDOUT region
// is skipped, control would fall through into the next case's code. Confirm
// how MIDOUT_END() handles that.
#define _cb_dispatch_mode(_m)                                                          \
    case Mode::_m:                                                                     \
        do {                                                                           \
            using KernImpl = ElemwiseKern<                                             \
                    megcorePlatformCPU, param_enumv::Elemwise::Mode::_m, ctype>;       \
            MIDOUT_BEGIN(                                                              \
                    megdnn_naive_elemwise,                                             \
                    midout_iv(param_enumv::Elemwise::Mode::_m)) {                      \
                auto params = src;                                                     \
                MEGDNN_DISPATCH_CPU_KERN(                                              \
                        handle, ElemArithKernCaller<arity MEGDNN_COMMA KernImpl>::run( \
                                        dst.ptr<ctype>(), params));                    \
                return;                                                                \
            }                                                                          \
            MIDOUT_END();                                                              \
        } while (0);

再看看 ElemwiseKern 中 CLIP 算子的定义：

struct ElemwiseKern;



DEF_KERN_ALL(CLIP, x <= y ? y : (x <= z ? x : z));

//! define kernel for all ctypes
#define DEF_KERN_ALL(_mode, _imp) \
    DEF_KERN_INT(_mode, _imp);    \
    DEF_KERN_FLOAT(_mode, _imp);

//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp)                     \
    DEF_KERN(dt_float32, _mode, _imp);                  \
    DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
    DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp)    \
    DEF_KERN(dt_int32, _mode, _imp); \
    DEF_KERN(dt_int16, _mode, _imp); \
    DEF_KERN(dt_int8, _mode, _imp);  \
    DEF_KERN(dt_uint8, _mode, _imp);


//! define kernel for a single ctype
#define DEF_KERN(_ctype, _mode, _imp)                                             \
    template <megcorePlatform_t plat>                                             \
    struct ElemwiseKern<plat, param_enumv::Elemwise::Mode::_mode, _ctype> {       \
        typedef _ctype ctype;                                                     \
        static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \
    }

ElemArithKernCaller 是实际执行算子计算的代码。其中 KernImpl 就是前面定义的 ElemwiseKern；整个调用被包装在 MEGDNN_DISPATCH_CPU_KERN 中，交由 CPU 调度执行。

/*!
 * \brief Dispatch a CPU kernel through a handle. Operator implementations
 *      should use this macro to run their kernels.
 *
 * Wraps the statement(s) `_stmt` in a lambda that captures everything by
 * value and passes that lambda to `_handle->dispatch_kern()`. Whether the
 * kernel runs immediately or is queued depends on the handle's dispatch_kern,
 * which is not visible here. Because of the by-value capture, the kernel's
 * inputs would still be valid if it runs later.
 * \param _handle a pointer to HandleImpl
 * \param _stmt the statements to be executed for the kernel. This is a single
 *      macro argument, so any comma not enclosed in parentheses (for example in
 *      a template argument list) must be written as MEGDNN_COMMA.
 */
#define MEGDNN_DISPATCH_CPU_KERN(_handle, _stmt) \
    do {                                         \
        auto _kern = [=]() { _stmt; };           \
        _handle->dispatch_kern(_kern);           \
    } while (0)


/*!
 * \brief Apply an elemwise kernel to every output element.
 *
 * KernImpl supplies the element type (`KernImpl::ctype`) and a static
 * `apply()` that takes one value per input. run() walks each input tensor
 * with its own value iterator, advancing all of them together. For
 * idx in [0, param.size) it stores apply(...) of the idx-th values into
 * dest[idx].
 *
 * The primary template only declares run(). Arities 1, 2 and 3 are
 * implemented by the specializations below.
 */
template <int arity, class KernImpl>
struct ElemArithKernCaller {
    using ctype = typename KernImpl::ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<arity>& param);
};

//! Unary kernels: dest[idx] = apply(a_idx).
template <class KernImpl>
struct ElemArithKernCaller<1, KernImpl> {
    using ctype = typename KernImpl::ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<1>& param) {
        auto in_a = tensor_iter_valonly<ctype>(param[0]).begin();
        for (size_t idx = 0; idx < param.size; ++idx, ++in_a) {
            dest[idx] = KernImpl::apply(*in_a);
        }
    }
};

//! Binary kernels: dest[idx] = apply(a_idx, b_idx).
template <class KernImpl>
struct ElemArithKernCaller<2, KernImpl> {
    using ctype = typename KernImpl::ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<2>& param) {
        auto in_a = tensor_iter_valonly<ctype>(param[0]).begin();
        auto in_b = tensor_iter_valonly<ctype>(param[1]).begin();
        for (size_t idx = 0; idx < param.size; ++idx, ++in_a, ++in_b) {
            dest[idx] = KernImpl::apply(*in_a, *in_b);
        }
    }
};

//! Ternary kernels (e.g. CLIP): dest[idx] = apply(a_idx, b_idx, c_idx).
template <class KernImpl>
struct ElemArithKernCaller<3, KernImpl> {
    using ctype = typename KernImpl::ctype;
    static void run(ctype* dest, const ElemwiseOpParamN<3>& param) {
        auto in_a = tensor_iter_valonly<ctype>(param[0]).begin();
        auto in_b = tensor_iter_valonly<ctype>(param[1]).begin();
        auto in_c = tensor_iter_valonly<ctype>(param[2]).begin();
        for (size_t idx = 0; idx < param.size; ++idx, ++in_a, ++in_b, ++in_c) {
            dest[idx] = KernImpl::apply(*in_a, *in_b, *in_c);
        }
    }
};

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值