class KernelDef {
private:
// note that input/output might be on CPU implicitly when the node is from CPU execution provider
constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
}
public:
explicit KernelDef() = default;
const std::string& OpName() const {
return op_name_;
}
const std::string& Domain() const {
return op_domain_;
}
void SinceVersion(/*out*/ int* start, /*out*/ int* end) const {
*start = op_since_version_start_;
*end = op_since_version_end_;
}
private:
friend class KernelDefBuilder;
// The operator name supported by <*this> kernel.
std::string op_name_;
// The operator since_version range supported by <*this> kernel.
// A kernel could support an operator definition between <op_since_version_start>
// and <op_since_version_end> (inclusive).
int op_since_version_start_ = 1;
int op_since_version_end_ = INT_MAX;
// The operator domain supported by <*this> kernel.
// Default to 'onnxruntime::kOnnxDomain'.
// Please note the behavior of std::string("") and std::string() are different
std::string op_domain_;
// The type of the execution provider.
std::string provider_type_;
// The data types that are supported in this build (enabled) for inputs/outputs.
// Key is input/output/type constraint name defined in op schema, Value is supported types.
std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;
// An element <i, j> means that output j reuses the memory of input i.
std::vector<std::pair<int, int>> inplace_map_;
// An element <i, j> means that output j is an alias of input i.
std::vector<std::pair<int, int>> alias_map_;
// This variable stores <input_offset, output_offset> for the variadic alias mapping
// output 'i + output_offset' is an alias of input 'i + input_offset' for all i >= 0
std::optional<std::pair<int, int>> variadic_alias_offsets_;
// Require input tensors to be allocated contiguously.
bool allocate_inputs_contiguously_ = false;
// Whether the outputs are from external.
bool external_outputs_ = false;
#ifdef ENABLE_STRIDED_TENSORS
// An element i means i-th input can be strided tensor.
std::vector<int> may_strided_inputs_;
// An element <i, j> means j-th output can be a strided tensor, which share the data from i-th input.
std::vector<std::pair<int, int>> may_strided_output_map_;
#endif
// The memory types of inputs/outputs of this kernel
MemTypeMap input_memory_type_args_;
MemTypeMap output_memory_type_args_;
// execution command queue id, 0 for default queue in execution provider
int exec_queue_id_ = 0;
// Default memory type for all inputs
OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
// Default memory type for all outputs
OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};
};
The key member here is type_constraints_.
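As a quick illustration of how type_constraints_ gets populated (a minimal sketch assembled from the excerpts in this article, not code copied from ORT): each TypeConstraint() call on KernelDefBuilder records the allowed MLDataTypes under the constraint name from the op schema, and kernel matching later checks the node's resolved types against that list.
// Sketch only: type_constraints_["T"] ends up holding just the float tensor type.
auto kernel_def = KernelDefBuilder()
                      .SetName("Conv")
                      .SetDomain(kOnnxDomain)
                      .SinceVersion(11)
                      .Provider(kCpuExecutionProvider)
                      .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
                      .Build();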
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
return info;
}
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
  static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<void>,  // default entry so the table is never empty after op reduction
      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv)>,
      // ... one entry per registered kernel
  };
  // the loop over function_table that performs the actual registration is shown below
}
Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
#ifndef DISABLE_ML_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry));
#endif
#ifndef DISABLE_CONTRIB_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry));
#endif
#if defined(ENABLE_TRAINING_OPS)
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuTrainingKernels(kernel_registry));
#endif
return Status::OK();
}
for (auto& function_table_entry : function_table) {
KernelCreateInfo info = function_table_entry();
if (info.kernel_def != nullptr) { // filter disabled entries where type is void
ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
}
}
This loop lives inside RegisterOnnxOperatorKernels; function_table is the array defined above. Each BuildKernelCreateInfo specialization the table points to is generated by the kernel registration macros, whose core expansion (typed variant shown) is:
template <> \
KernelCreateInfo \
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
return KernelCreateInfo( \
builder.SetName(#name) \
.SetDomain(domain) \
.SinceVersion(ver) \
.Provider(provider) \
.Build(), \
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
}
class OpKernel {
public:
using DoneCallback = std::function<void()>;
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
virtual ~OpKernel() = default;
const onnxruntime::Node& Node() const;
const onnxruntime::KernelDef& KernelDef() const;
[[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;
[[nodiscard]] virtual bool IsAsync() const {
// by default all kernels are sync version.
return false;
}
[[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
// Override this function to PrePack initialized constant tensor to the format as needed.
// For example, MatMul kernel can pack the input B if it is constant like code below.
// Status PrePack(const Tensor& tensor, int input_idx, /*out*/ bool& is_packed,
// /*out*/ PrePackedWeights* prepacked_weight_for_caching,
// AllocatorPtr alloc) override {
// is_packed = false;
// if (input_idx == 1) {
// is_packed = true;
// this.Pack(tensor, this.buffer_, alloc);
// if (prepacked_weight_for_caching) {
// // LOGIC TO CACHE `this.buffer_` SINCE THE KERNEL DOESN'T OWN THE PACKED WEIGHT
// }
// }
// return Status::OK();
// }
// Please refer to MatMulIntegerToFloatBase for a complete example
// @param tensor: The initialized constant tensor
// @param input_idx: The input index of the tensor in this kernel
// @param alloc: The kernel's PrePack() method MUST use this allocator for allocating the pre-packed
// weights' buffers. The alloc that the PrePack() method will receive will be either
// the allocator tied to the session if the kernel owns the pre-packed buffer or an
// allocator shared between sessions if the pre-packed buffer is to be shared across sessions
// (i.e.) the kernel does not own the buffer.
// @param is_packed: Set it to true if the kernel packed the tensor or to false
// The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
// and the original initialized constant tensor will be released and not accessible anymore in
// the Compute function.
// @param prepacked_weights: A PrePackedWeights instance will be provided to the kernel IF the pre-packed weights
// are meant to be stored in a shared container.
virtual Status
PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
is_packed = false;
return Status::OK();
}
// Override this function to use provided pre-packed weight.
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
// int input_idx,
// /*out*/ bool& used_shared_buffers) {
// used_shared_buffers = true;
// this.buffer_ = std::move(prepacked_buffers[0]);
// return Status::OK();
// }
// Please refer to MatMulIntegerToFloatBase for a complete example
// @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index.
// (Sometimes a single constant initializer may have multiple pre-packed buffers associated
// with it; it is up to the kernel developer to store them in any order of their choice in PrePack(),
// and the same order must be used for retrieval in UseSharedPrePackedBuffers().)
// @param input_idx: The input index of the tensor in this kernel
// @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
// that the provided weight has been used by the kernel.
virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
int /*input_idx*/,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
return Status::OK();
}
const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
}
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
std::unique_ptr<OpKernelInfo> op_kernel_info_;
};
OpKernel holds the OpKernelInfo.
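A concrete kernel simply derives from OpKernel and overrides Compute(). A minimal sketch (assuming the usual OpKernelContext accessors Input<Tensor>() and Output(); this is illustrative, not the Slice implementation shown next):
class MyIdentityKernel final : public OpKernel {
 public:
  explicit MyIdentityKernel(const OpKernelInfo& info) : OpKernel(info) {}
  Status Compute(OpKernelContext* ctx) const override {
    const Tensor* X = ctx->Input<Tensor>(0);   // first input tensor
    Tensor* Y = ctx->Output(0, X->Shape());    // allocate output with the same shape
    memcpy(Y->MutableDataRaw(), X->DataRaw(), X->SizeInBytes());  // raw element copy
    return Status::OK();
  }
};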
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Slice,
1, 9,
KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>()),
Slice1);
The last argument, Slice1, is the class that actually implements the OpKernel. The macro expands to:
class kCpuExecutionProvider_Slice_kOnnxDomain_ver1_9;
template <> KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver1_9>() {
return KernelCreateInfo(
KernelDefBuilder()
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>())
.SetName("Slice")
.SetDomain(kOnnxDomain)
.SinceVersion(1, 9)
.Provider(kCpuExecutionProvider)
.Build(),
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<Slice1>(info); return Status::OK(); })
);
}
The class-name macro simply mangles provider, op, domain, and version into an identifier, e.g.:
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice)>
=> BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver13>
ONNX_CPU_OPERATOR_KERNEL(
Slice,
13,
KernelDefBuilder()
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<EnabledIndicesTypes>()),
Slice10);
class kCpuExecutionProvider_Slice_kOnnxDomain_ver13;
template <> KernelCreateInfo BuildKernelCreateInfo<kCpuExecutionProvider_Slice_kOnnxDomain_ver13>() {
return KernelCreateInfo(
KernelDefBuilder()
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<EnabledDataTypes>())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<EnabledIndicesTypes>())
.SetName("Slice")
.SetDomain(kOnnxDomain)
.SinceVersion(13)
.Provider(kCpuExecutionProvider)
.Build(),
static_cast<KernelCreatePtrFn>( [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<Slice10>(info); return Status::OK(); })
);
}
// Kernel create function map from op name to kernel creation info.
// key is opname+domain_name+provider_name
// eg. Key = "Clip ai.onnx CPUExecutionProvider"
KernelCreateMap kernel_creator_fn_map_;
// KernelCreateInfo for each node so we do kernel lookup once
KernelCreateInfoMap kernel_create_info_map_;
// cache of the constructed kernels to avoid spending construction time per executor
std::vector<std::unique_ptr<OpKernel>> session_kernels_;
// assumes vector is already resize()'ed to the number of nodes in the graph
ORT_RETURN_IF_ERROR(kernel_registry_manager.CreateKernel(node, exec_provider, *this, kci, session_kernels_[node.Index()]));
namespace onnxruntime {
Status KernelRegistryManager::CreateKernel(const Node& node,
const IExecutionProvider& execution_provider,
SessionState& session_state,
const KernelCreateInfo& kernel_create_info,
std::unique_ptr<OpKernel>& out) const {
OpKernelInfo kernel_info(node, *kernel_create_info.kernel_def, execution_provider,
session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(),
session_state.GetDataTransferMgr());
return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
}
For example, fp16 is not supported by the CPU Conv kernel, so kernel lookup fails with:
"This op has been implemented only for the following types (tensor(float),), but the node in the model has the following type (tensor(float16))"
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Conv,
1, 10,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv<float>);
ONNX_CPU_OPERATOR_KERNEL(
Conv,
11,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Conv<float>);
As you can see, the CPU Conv kernel is only registered for float.
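A purely hypothetical sketch: if a Conv<MLFloat16> implementation existed, enabling fp16 would just mean registering it with the matching type constraint, here using the typed registration macro (its parameter order is assumed to mirror the typed class-name macro shown earlier). The CPU provider does not ship such a kernel, which is exactly why the error above is raised.
// Hypothetical only: there is no Conv<MLFloat16> in the CPU execution provider today.
ONNX_CPU_OPERATOR_TYPED_KERNEL(
    Conv,
    11,
    MLFloat16,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
    Conv<MLFloat16>);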
/**
* \brief PrimitiveDataTypeBase
* Base class for primitive Tensor contained types
*
* \details This class contains an integer constant that can be
* used for input data type dispatching
*
*/
class PrimitiveDataTypeBase : public DataTypeImpl {
public:
bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
return false;
}
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
return nullptr;
}
int32_t GetDataType() const {
return data_type_;
}
protected:
PrimitiveDataTypeBase(size_t size, int32_t data_type)
: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
private:
const int32_t data_type_;
};
/**
* \brief PrimitiveDataType
* Typed specialization for primitive types.
* Concrete instances of this class are used by Tensor.
*
* \param T - primitive data type
*
*/
template <typename T>
class PrimitiveDataType : public PrimitiveDataTypeBase {
private:
static void Delete(void* p) {
delete static_cast<T*>(p);
}
public:
static MLDataType Type();
DeleteFunc GetDeleteFunc() const override {
return &Delete;
}
private:
PrimitiveDataType()
: PrimitiveDataTypeBase{sizeof(T),
utils::ToTensorProtoElementType<T>()} {
}
};
template <>
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<MLFloat16>() {
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
}
enum TensorProto_DataType : int {
TensorProto_DataType_UNDEFINED = 0,
TensorProto_DataType_FLOAT = 1,
TensorProto_DataType_UINT8 = 2,
TensorProto_DataType_INT8 = 3,
TensorProto_DataType_UINT16 = 4,
TensorProto_DataType_INT16 = 5,
TensorProto_DataType_INT32 = 6,
TensorProto_DataType_INT64 = 7,
TensorProto_DataType_STRING = 8,
TensorProto_DataType_BOOL = 9,
TensorProto_DataType_FLOAT16 = 10,
TensorProto_DataType_DOUBLE = 11,
TensorProto_DataType_UINT32 = 12,
TensorProto_DataType_UINT64 = 13,
TensorProto_DataType_COMPLEX64 = 14,
TensorProto_DataType_COMPLEX128 = 15,
TensorProto_DataType_BFLOAT16 = 16
};
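Tying the pieces together (a minimal compile-time check, assuming the relevant headers are in scope): the PrimitiveDataType<MLFloat16> singleton stores utils::ToTensorProtoElementType<MLFloat16>(), which is exactly the enum value above, so GetDataType() reports 10 for fp16 tensors.
static_assert(utils::ToTensorProtoElementType<MLFloat16>() ==
                  ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
              "MLFloat16 maps to TensorProto_DataType_FLOAT16 (10)");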