ONNX Runtime OpKernel

class KernelDef {
 private:
  // note that input/output might be on CPU implicitly when the node is from CPU execution provider
  constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
    return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
  }

 public:
  explicit KernelDef() = default;

  const std::string& OpName() const {
    return op_name_;
  }

  const std::string& Domain() const {
    return op_domain_;
  }

  void SinceVersion(/*out*/ int* start, /*out*/ int* end) const {
    *start = op_since_version_start_;
    *end = op_since_version_end_;
  }

 private:
  friend class KernelDefBuilder;

  // The operator name supported by <*this> kernel.
  std::string op_name_;

  // The operator since_version range supported by <*this> kernel.
  // A kernel could support an operator definition between <op_since_version_start>
  // and <op_since_version_end> (inclusive).
  int op_since_version_start_ = 1;
  int op_since_version_end_ = INT_MAX;

  // The operator domain supported by <*this> kernel.
  // Default to 'onnxruntime::kOnnxDomain'.
  // Please note the behavior of std::string("") and std::string() are different
  std::string op_domain_;

  // The type of the execution provider.
  std::string provider_type_;

  // The data types that are supported in this build (enabled) for inputs/outputs.
  // Key is input/output/type constraint name defined in op schema, Value is supported types.
  std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;

  // An element <i, j> means that output j reuses the memory of input i.
  std::vector<std::pair<int, int>> inplace_map_;

  // An element <i, j> means that output j is an alias of input i.
  std::vector<std::pair<int, int>> alias_map_;

  // This variable stores <input_offset, output_offset> for the variadic alias mapping
  // output 'i + output_offset' is an alias of input 'i + input_offset' for all i >= 0
  std::optional<std::pair<int, int>> variadic_alias_offsets_;

  // Require input tensors to be allocated contiguously.
  bool allocate_inputs_contiguously_ = false;

  // Whether the outputs are from external.
  bool external_outputs_ = false;

#ifdef ENABLE_STRIDED_TENSORS
  // An element i means i-th input can be strided tensor.
  std::vector<int> may_strided_inputs_;

  // An element <i, j> means j-th output can be a strided tensor, which share the data from i-th input.
  std::vector<std::pair<int, int>> may_strided_output_map_;
#endif

  // The memory types of inputs/outputs of this kernel
  MemTypeMap input_memory_type_args_;
  MemTypeMap output_memory_type_args_;

  // execution command queue id, 0 for default queue in execution provider
  int exec_queue_id_ = 0;
  // Default memory type for all inputs
  OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
  // Default memory type for all outputs
  OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};
};
The key member here is type_constraints_: it records, per type-constraint name from the op schema, which data types this kernel build actually supports.
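Conceptually, kernel matching walks each constraint in type_constraints_ and checks that the type a node actually uses is in the enabled list; if it is not, the kernel is rejected, which is what produces the "implemented only for the following types" error quoted near the end of this post. A hedged sketch of that check (illustrative only, not the actual onnxruntime lookup code):

// Illustrative only; the real matching lives in KernelRegistry / KernelDef and is more involved.
bool TypeSatisfiesConstraint(
    const std::unordered_map<std::string, std::vector<MLDataType>>& type_constraints,
    const std::string& constraint_name,  // e.g. "T" from the op schema
    MLDataType actual_type) {            // type the node's input/output resolved to
  auto it = type_constraints.find(constraint_name);
  if (it == type_constraints.end()) return true;  // no constraint registered for this name
  const auto& allowed = it->second;
  return std::find(allowed.begin(), allowed.end(), actual_type) != allowed.end();
}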
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
  KernelCreateInfo info;
  return info;
}

Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
  static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<void>,  // default entry so the table is never empty when ops are excluded
      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv)>,
      // ... one entry per CPU kernel ...
  };
  // ... registration loop shown below ...
}
Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
  ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
#ifndef DISABLE_ML_OPS
  ORT_RETURN_IF_ERROR(::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry));
#endif
#ifndef DISABLE_CONTRIB_OPS
  ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry));
#endif
#if defined(ENABLE_TRAINING_OPS)
  ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuTrainingKernels(kernel_registry));
#endif
  return Status::OK();
}
  for (auto& function_table_entry : function_table) {
    KernelCreateInfo info = function_table_entry();
    if (info.kernel_def != nullptr) {  // filter disabled entries where type is void
      ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
    }
  }

function_table is the array defined above. Each entry is a BuildKernelCreateInfoFn; entries whose op was excluded from the build resolve to BuildKernelCreateInfo<void>, which returns a KernelCreateInfo with a null kernel_def, so the loop skips them.

template <>                                                                                                                      \
  KernelCreateInfo                                                                                                                 \
  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() {                              \
    return KernelCreateInfo(                                                                                                       \
        builder.SetName(#name)                                                                                                     \
            .SetDomain(domain)                                                                                                     \
            .SinceVersion(ver)                                                                                                     \
            .Provider(provider)                                                                                                    \
            .Build(),                                                                                                              \
        static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
  }
class OpKernel {
 public:
  using DoneCallback = std::function<void()>;

  explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
  virtual ~OpKernel() = default;

  const onnxruntime::Node& Node() const;
  const onnxruntime::KernelDef& KernelDef() const;

  [[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;

  [[nodiscard]] virtual bool IsAsync() const {
    // by default all kernels are sync version.
    return false;
  }

  [[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
  }

  // Override this function to PrePack initialized constant tensor to the format as needed.
  // For example, MatMul kernel can pack the input B if it is constant like code below.
  //   Status PrePack(const Tensor& tensor, int input_idx, /*out*/ bool& is_packed,
  //                  /*out*/ PrePackedWeights* prepacked_weight_for_caching,
  //                  AllocatorPtr alloc) override {
  //     is_packed = false;
  //     if (input_idx == 1) {
  //       is_packed = true;
  //       this.Pack(tensor, this.buffer_, alloc);
  //       if (prepacked_weight_for_caching) {
  //           // LOGIC TO CACHE `this.buffer_` SINCE THE KERNEL DOESN'T OWN THE PACKED WEIGHT
  //       }
  //     }
  //     return Status::OK();
  //   }
  // Please refer to MatMulIntegerToFloatBase for a complete example
  // @param tensor: The initialized constant tensor
  // @param input_idx: The input index of the tensor in this kernel
  // @param alloc: The kernel's PrePack() method MUST use this allocator for allocating the pre-packed
  //               weights' buffers. The alloc that the PrePack() method will receive will be either
  //               the allocator tied to the session if the kernel owns the pre-packed buffer or an
  //               allocator shared between sessions if the pre-packed buffer is to be shared across sessions
  //               (i.e.) the kernel does not own the buffer.
  // @param is_packed: Set it to true if the kernel packed the tensor or to false
  //                   The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
  //                   and the original initialized constant tensor will be released and not accessible anymore in
  //                   the Compute function.
  // @param prepacked_weights: A PrePackedWeights instance will be provided to the kernel IF the pre-packed weights
  //                          are meant to be stored in a shared container.

  virtual Status
  PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
          /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
    is_packed = false;
    return Status::OK();
  }

  // Override this function to use provided pre-packed weight.
  // Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
  //                                 int input_idx,
  //                                 /*out*/ bool& used_shared_buffers) {
  //     used_shared_buffers = true;
  //     this.buffer_ = std::move(prepacked_buffers[0]);
  //     return Status::OK();
  //   }
  // Please refer to MatMulIntegerToFloatBase for a complete example
  // @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index
  //                           (Sometimes a single constant initializer may have multiple pre-packed buffers associated
  //                            with it; it is up to the kernel developer to store them in any order of their choice in PrePack(),
  //                            and the same order must be used for retrieval in UseSharedPrePackedBuffers().)
  // @param input_idx: The input index of the tensor in this kernel
  // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
  // that the provided weight has been used by the kernel.
  virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
                                           int /*input_idx*/,
                                           /*out*/ bool& used_shared_buffers) {
    used_shared_buffers = false;
    return Status::OK();
  }

  const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const;
  const OpKernelInfo& Info() const {
    return *op_kernel_info_;
  }

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
  std::unique_ptr<OpKernelInfo> op_kernel_info_;
};

OpKernel owns an OpKernelInfo (op_kernel_info_), accessible via Info().
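To make the interface concrete, here is a minimal hypothetical CPU kernel written against the classes above (MyRelu is not part of ONNX Runtime; it only shows the Compute override and the OpKernelContext input/output accessors):

// Hypothetical kernel, for illustration only.
class MyRelu final : public OpKernel {
 public:
  explicit MyRelu(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* context) const override {
    const Tensor* X = context->Input<Tensor>(0);
    Tensor* Y = context->Output(0, X->Shape());  // output 0, same shape as the input
    const float* x = X->Data<float>();
    float* y = Y->MutableData<float>();
    for (int64_t i = 0, n = X->Shape().Size(); i < n; ++i) {
      y[i] = x[i] > 0.f ? x[i] : 0.f;
    }
    return Status::OK();
  }
};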

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
    Slice,
    1, 9,
    KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>()),
    Slice1);

The last argument, Slice1, is the actual OpKernel implementation class. The macro expands to:

class kCpuExecutionProvider_Slice_kOnnxDomain_ver1_9; 
template <> KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver1_9>() { 
        return KernelCreateInfo( 
            KernelDefBuilder()
            .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>())
            .SetName("Slice") 
            .SetDomain(kOnnxDomain) 
            .SinceVersion(1, 9) 
            .Provider(kCpuExecutionProvider) 
            .Build(), 
            static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<Slice1>(info); return Status::OK(); })
        ); 
}

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice)>

=> BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver13>
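The kernel class name is produced purely by token pasting. A sketch of the pattern (MY_KERNEL_CLASS_NAME is a stand-in; see the real macros in op_kernel.h):

// provider=kCpuExecutionProvider, domain=kOnnxDomain, ver=13, name=Slice
//   -> kCpuExecutionProvider_Slice_kOnnxDomain_ver13
#define MY_KERNEL_CLASS_NAME(provider, domain, ver, name) \
  provider##_##name##_##domain##_ver##ver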





ONNX_CPU_OPERATOR_KERNEL(
    Slice,
    13,
    KernelDefBuilder()
        .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<EnabledIndicesTypes>()),
    Slice10);


class kCpuExecutionProvider_Slice_kOnnxDomain_ver13; 
template <> KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver13>() { 
    return KernelCreateInfo( 
        KernelDefBuilder() 
        .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>()) 
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<EnabledIndicesTypes>())
        .SetName("Slice") 
        .SetDomain(kOnnxDomain) 
        .SinceVersion(13) 
        .Provider(kCpuExecutionProvider) 
        .Build(), 
        static_cast<KernelCreatePtrFn>( [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<Slice10>(info); return Status::OK(); })
    );
}

// Kernel create function map from op name to kernel creation info.
// Key is op_name + domain_name + provider_name,
// e.g. Key = "Clip ai.onnx CPUExecutionProvider"
KernelCreateMap kernel_creator_fn_map_;
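The lookup key is just the three fields joined with spaces. A hedged sketch with a hypothetical helper (the real key construction lives inside KernelRegistry):

// Hypothetical helper, illustrating the key format only.
std::string MakeKernelMapKey(const std::string& op_name, const std::string& domain,
                             const std::string& provider_type) {
  // An empty domain means the default ONNX domain, which is printed as "ai.onnx".
  std::string d = domain.empty() ? "ai.onnx" : domain;
  return op_name + " " + d + " " + provider_type;  // e.g. "Clip ai.onnx CPUExecutionProvider"
}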

// KernelCreateInfo for each node so we do kernel lookup once
KernelCreateInfoMap kernel_create_info_map_;

// cache of the constructed kernels to avoid spending construction time per executor
std::vector<std::unique_ptr<OpKernel>> session_kernels_;

// assumes vector is already resize()'ed to the number of nodes in the graph
ORT_RETURN_IF_ERROR(kernel_registry_manager.CreateKernel(node, exec_provider, *this, kci, session_kernels_[node.Index()]));
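For context, a hedged sketch of the surrounding per-node loop (not the literal session_state code; GetNodeKernelCreateInfo stands in for however the cached kernel_create_info_map_ entry is retrieved, and execution-provider resolution is simplified):

// Sketch only: one OpKernel instance is constructed and cached per graph node,
// so Run() never pays kernel-construction cost.
session_kernels_.resize(graph.NumberOfNodes());
for (const auto& node : graph.Nodes()) {
  const IExecutionProvider* exec_provider = execution_providers.Get(node.GetExecutionProviderType());
  const KernelCreateInfo& kci = GetNodeKernelCreateInfo(node.Index());  // from kernel_create_info_map_
  ORT_RETURN_IF_ERROR(kernel_registry_manager.CreateKernel(node, *exec_provider, *this, kci,
                                                           session_kernels_[node.Index()]));
}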
namespace onnxruntime {
Status KernelRegistryManager::CreateKernel(const Node& node,
                                           const IExecutionProvider& execution_provider,
                                           SessionState& session_state,
                                           const KernelCreateInfo& kernel_create_info,
                                           std::unique_ptr<OpKernel>& out) const {
  OpKernelInfo kernel_info(node, *kernel_create_info.kernel_def, execution_provider,
                           session_state.GetConstantInitializedTensors(),
                           session_state.GetOrtValueNameIdxMap(),
                           session_state.GetDataTransferMgr());

  return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
}

When fp16 is not supported, kernel lookup fails with an error like:

"This op has been implemented only for the following types (tensor(float),), but the node in the model has the following type (tensor(float16))"

ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
    Conv,
    1, 10,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Conv<float>);

ONNX_CPU_OPERATOR_KERNEL(
    Conv,
    11,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Conv<float>);

}  // namespace onnxruntime

As shown above, the CPU Conv kernel is only implemented for float.
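Supporting fp16 would require both an MLFloat16 kernel implementation and a registration whose type constraint includes it. A hypothetical sketch (nothing like this ships in the CPU provider; the typed-kernel macro name and its argument order are assumed to parallel the ONNX_CPU_OPERATOR_KERNEL macros above):

// Hypothetical only: assumes a Conv<MLFloat16> implementation exists.
ONNX_CPU_OPERATOR_TYPED_KERNEL(
    Conv,
    11,
    MLFloat16,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
    Conv<MLFloat16>);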

/**
 * \brief PrimitiveDataTypeBase
 *        Base class for primitive Tensor contained types
 *
 * \details This class contains an integer constant that can be
 *          used for input data type dispatching
 *
 */
class PrimitiveDataTypeBase : public DataTypeImpl {
 public:
  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
    return false;
  }

  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
    return nullptr;
  }

  int32_t GetDataType() const {
    return data_type_;
  }

 protected:
  PrimitiveDataTypeBase(size_t size, int32_t data_type)
      : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}

 private:
  const int32_t data_type_;
};

/**
 * \brief PrimitiveDataType
 *        Typed specialization for primitive types.
 *        Concrete instances of this class are used by Tensor.
 *
 * \param T - primitive data type
 *
 */
template <typename T>
class PrimitiveDataType : public PrimitiveDataTypeBase {
 private:
  static void Delete(void* p) {
    delete static_cast<T*>(p);
  }

 public:
  static MLDataType Type();

  DeleteFunc GetDeleteFunc() const override {
    return &Delete;
  }

 private:
  PrimitiveDataType()
      : PrimitiveDataTypeBase{sizeof(T),
                              utils::ToTensorProtoElementType<T>()} {
  }
};

template <>
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<MLFloat16>() {
  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
}

enum TensorProto_DataType : int {
  TensorProto_DataType_UNDEFINED = 0,
  TensorProto_DataType_FLOAT = 1,
  TensorProto_DataType_UINT8 = 2,
  TensorProto_DataType_INT8 = 3,
  TensorProto_DataType_UINT16 = 4,
  TensorProto_DataType_INT16 = 5,
  TensorProto_DataType_INT32 = 6,
  TensorProto_DataType_INT64 = 7,
  TensorProto_DataType_STRING = 8,
  TensorProto_DataType_BOOL = 9,
  TensorProto_DataType_FLOAT16 = 10,
  TensorProto_DataType_DOUBLE = 11,
  TensorProto_DataType_UINT32 = 12,
  TensorProto_DataType_UINT64 = 13,
  TensorProto_DataType_COMPLEX64 = 14,
  TensorProto_DataType_COMPLEX128 = 15,
  TensorProto_DataType_BFLOAT16 = 16
};
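Putting the pieces together: the MLDataType for MLFloat16 is a PrimitiveDataType<MLFloat16> whose GetDataType() returns TensorProto_DataType_FLOAT16 (10), via the ToTensorProtoElementType<MLFloat16>() specialization above. A hedged sketch (assumes DataTypeImpl::GetType<T>() and AsPrimitiveDataType() behave as declared in data_types.h):

int32_t Float16ElementType() {
  MLDataType f16 = DataTypeImpl::GetType<MLFloat16>();            // PrimitiveDataType<MLFloat16>
  const PrimitiveDataTypeBase* prim = f16->AsPrimitiveDataType(); // nullptr for non-primitive types
  return prim->GetDataType();                                     // TensorProto_DataType_FLOAT16 == 10
}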
