Torch 自定义算子
上代码
从torch导出算子
class KDTreeFpsample(torch.autograd.Function):
@staticmethod
def symbolic(g, xyz: torch.Tensor, nsample: int):
return g.op('WX::KDTreeFpsample', xyz, nsample_i=nsample)
@staticmethod
def forward(ctx, xyz: torch.Tensor, nsample: int):
fps_idx = kdtree_fpsample(xyz, nsample)
return fps_idx
# 示例使用
# xyz (B x N x 3), nsample 采样点
# fps_idx (B x nsample) 采样后的索引
fps_idx = KDTreeFpsample.apply(xyz, nsample)
def symbolic(g, xyz: torch.Tensor, nsample: int):
和
forward(ctx, xyz: torch.Tensor, nsample: int):
对应的,xyz是输入,nsample就是一个标量数据
return g.op('WX::KDTreeFpsample', xyz, nsample_i=nsample)
这里的 WX 是一个防止算子名称重复的命名空间,KDTreeFpsample是算子名称,如果导出为onnx的话,在加载onnx模型时也需要自定义一个onnx的算子,onnx的算子也会是这个名称。
nsample_i=nsample
这里这么写的原因是nsample只是一个标量,就是一个常数,_i 表示是一个整数标量,这是表示这个算子有一个 nsample 的标量属性。具体的请参考别的资料,我也说不清楚。
这里可以继续看到后面加载onnx模型时,就知道什么意思了。
fps_idx = kdtree_fpsample(xyz, nsample)
这是一个使用C++扩展的函数
这种方式定义的算子不需要额外的去显示注册,你只要使用了
KDTreeFpsample.apply
导出为onnx的时候就会自动把这个算子加入的。
导出onnx模型的方式请参考逼得相关教程。
在onnxruntime的python版本里,加载有自定义算子的onnx模型
如果使用上面自定义的算子导出onnx模型,那加载这个onnx模型的时候,也需要实现和注册这个算子,以下是我自己使用的方式
需要 包 onnxruntime 和 onnxruntime_extensions
import onnxruntime as ort
from onnxruntime_extensions import onnx_op, PyCustomOpDef
from onnxruntime_extensions import get_library_path as _lib_path
@onnx_op(op_type="WX::KDTreeFpsample",
inputs=[PyCustomOpDef.dt_float], # 输入类型
outputs=[PyCustomOpDef.dt_int32], # 输出类型
attrs={'nsample': PyCustomOpDef.dt_int64} # 这里就是torch算子里定义的 nsample_i=nsample
)
def KDTreeFpsample(xyz: np.ndarray, **kwargs) -> np.ndarray:
"""
Args:
xyz : BxNxC
Returns:
_type_: Idx B x nsample
"""
nsample = kwargs.get('nsample', None)
fps_idx = kdtree_fpsample(xyz, nsample)
return fps_idx
so = ort.SessionOptions()
so.register_custom_ops_library(_lib_path())
# 然后就可以正常加载onnx模型了
# ......
fps_idx = kdtree_fpsample(xyz, nsample)
和上面的上面是一样的函数,是C++扩展的python函数
在onnxruntime的C++版本里,加载有自定义算子的onnx模型
onnxruntime的环境配置不用多说,也可以参考别的教程。
插件定义
base.h
文件
#pragma once
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <onnxruntime/onnxruntime_lite_custom_op.h>
#define CHECK_ATTRIBUTE( attribute_name ) CheckAttribute( ort_api, info, #attribute_name, attribute_name )
template <typename>
constexpr bool always_false = false;
template <typename T>
static void CheckAttribute( const OrtApi* ort_api, const OrtKernelInfo* info, const char* name, T& attribute ) {
OrtStatus* status = nullptr;
if constexpr ( std::is_same<T, int64_t>::value ) {
status = ort_api->KernelInfoGetAttribute_int64( info, name, &attribute );
}
else if constexpr ( std::is_same<T, int32_t>::value ) {
// status = ort_api->KernelInfoGetAttribute_int64( info, name, &attribute );
}
else if constexpr ( std::is_same<T, float>::value ) {
status = ort_api->KernelInfoGetAttribute_float( info, name, &attribute );
}
else {
static_assert( always_false<T>, "Unsupported type for KernelInfoGetAttribute" );
}
#ifdef _DEBUG
assert( status == nullptr && ( std::string( "Attribute '" ) + name + "' not found" ).c_str() );
#else
if ( status != nullptr ) {
std::cerr << "Failed to retrieve '" << name << "' attribute in release mode." << std::endl;
ort_api->ReleaseStatus( status );
}
// else {
// std::cout << name << ": " << attribute << std::endl;
// }
#endif
}
新建头文件 FurthestPointSampling.h
#pragma once
#include "base.h"
#include <onnxruntime/onnxruntime_cxx_api.h>
struct FurthestPointSampling {
int64_t nsample; // 下采样点数,就是上面的 nsample
FurthestPointSampling( const OrtApi* ort_api, const OrtKernelInfo* info );
void Compute( const Ort::Custom::Tensor<float>& xyz, Ort::Custom::Tensor<int32_t>& out_fps_idx );
};
新建源文件 FurthestPointSampling.cpp
#include "FurthestPointSampling.h"
#include <omp.h>
#include "../kdtree_fpsample/kdtree_fpsample.h"
FurthestPointSampling::FurthestPointSampling( const OrtApi* ort_api, const OrtKernelInfo* info )
: nsample( 0 ) {
CHECK_ATTRIBUTE( nsample ); // 就是在python torch 算子里定义的nsample
}
void FurthestPointSampling::Compute( const Ort::Custom::Tensor<float>& xyz, Ort::Custom::Tensor<int32_t>& out_fps_idx ) {
/*
In xyz: (B, N, 3)
Out fps_idx: (B, nsample)
*/
std::vector<int64_t> input_shape = xyz.Shape();
int64_t B = input_shape[ 0 ];
int64_t N = input_shape[ 1 ];
int64_t D = input_shape[ 2 ];
const float* xyz_raw = xyz.Data();
std::vector<int64_t> output_shape = { B, nsample};
int32_t* fps_idx_raw = out_fps_idx.Allocate( output_shape );
std::fill_n( fps_idx_raw, B * nsample, 0 );
#pragma omp parallel for
for ( int ib = 0; ib < B; ++ib ) {
const float* batch_xyz = xyz_raw + ib * N * D;
wx::kdline_fpsample( batch_xyz, N, nsample, fps_idx_raw + ib * nsample );
}
}
注册插件
/// 注册自定义算子
#include "./custom_ops/FurthestPointSampling.h"
void RegisterOps( Ort::CustomOpDomain& domain ) {
domain.Add( Ort::Custom::CreateLiteCustomOp<FurthestPointSampling>( "FurthestPointSampling", "CPUExecutionProvider" ) ); // 名字要对应
}
// 其余代码
// 注册
m_env = std::make_unique<Ort::Env>( ORT_LOGGING_LEVEL_WARNING, "Onn" );
// 创建会话选项
Ort::SessionOptions session_options;
session_options.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
// 注册自定义算子
Ort::CustomOpDomain custom_domain{ "WX" };
RegisterOps( custom_domain );
session_options.Add( custom_domain );
// 其余代码
C++版本的自定义onnx算子就这样的。
在TensorRT的C++版本里,加载有自定义算子的onnx模型
见我的另一个博客文章