## 前言
新版本的 tensorrt 自定义插件有了新的实现方式。在之前的版本中,网上的资料要么说需要从头编译整个 tensorrt,要么流程非常繁琐,而且文档像一坨屎一样难看。
更新:
写了一个简单的示例,仅供参考:
make-a/tensorrt-plugin-example
## 自定义插件
### 1,先下载新版本的tensorrt 10.x库,这有很多教程
TensorRT 10.x Download | NVIDIA Developer
然后就是配置环境,这也有很多教程,我这里是 vs2022 加 TensorRT-10.6.0.26
### 2,新建插件类,实现插件接口
新版的插件需要实现新的插件接口:
#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>
namespace nvinfer1 {
class FpsamplePlugin : public nvinfer1::IPluginV3,
public nvinfer1::IPluginV3OneCore,
public nvinfer1::IPluginV3OneBuild,
public nvinfer1::IPluginV3OneRuntime {
public:
FpsamplePlugin( int32_t nsample );
我这里就直接贴出这个插件的完整代码了,可以参考;每个函数的具体含义,可以查看文档;这里仅仅是参考。
#pragma once
#include <string>
#include <vector>
#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>
using namespace nvinfer1;
namespace nvinfer1 {
/// TensorRT plugin object that implements the core, build and runtime
/// capability interfaces on a single class.
///
/// It carries one parameter, `nsample`, which is also the only field exposed
/// for serialization. The plugin is declared with exactly one input and one
/// output (see m_in_n / m_out_n).
class FpsamplePlugin : public nvinfer1::IPluginV3,
                       public nvinfer1::IPluginV3OneCore,
                       public nvinfer1::IPluginV3OneBuild,
                       public nvinfer1::IPluginV3OneRuntime {
public:
    /// @param nsample  Number of samples; used as the second output dimension.
    FpsamplePlugin( int32_t nsample );

    // ---- IPluginV3 ----
    IPluginCapability* getCapabilityInterface( PluginCapabilityType type ) noexcept override;
    IPluginV3* clone() noexcept override;

    // ---- IPluginV3OneCore ----
    AsciiChar const* getPluginName() const noexcept override;
    AsciiChar const* getPluginVersion() const noexcept override;
    AsciiChar const* getPluginNamespace() const noexcept override;

    // ---- IPluginV3OneBuild ----
    int32_t configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs,
                             DynamicPluginTensorDesc const* out, int32_t nbOutputs ) noexcept override;
    int32_t getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes,
                                int32_t nbInputs ) const noexcept override;
    int32_t getOutputShapes( DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
                             int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
                             IExprBuilder& exprBuilder ) noexcept override;
    size_t getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                             DynamicPluginTensorDesc const* outputs,
                             int32_t nbOutputs ) const noexcept override;
    bool supportsFormatCombination( int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs,
                                    int32_t nbOutputs ) noexcept override;
    int32_t getNbOutputs() const noexcept override;

    // ---- IPluginV3OneRuntime ----
    int32_t onShapeChange( PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out,
                           int32_t nbOutputs ) noexcept override;
    int32_t enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
                     void const* const* inputs, void* const* outputs, void* workspace,
                     cudaStream_t stream ) noexcept override;
    IPluginV3* attachToContext( IPluginResourceContext* context ) noexcept override;
    PluginFieldCollection const* getFieldsToSerialize() noexcept override;

private:
    int32_t m_nsample;                           // sample count supplied at construction
    PluginFieldCollection m_pfc;                 // view over m_pluginAttributes handed to TensorRT
    std::vector<PluginField> m_pluginAttributes; // backing storage for m_pfc.fields
    int32_t m_in_n = 1;                          // expected number of inputs
    int32_t m_out_n = 1;                         // expected number of outputs
};
/// Factory registered with the TensorRT plugin registry; it advertises the
/// plugin's field names and builds FpsamplePlugin instances from a
/// PluginFieldCollection.
class FpsamplePluginCreator : public nvinfer1::IPluginCreatorV3One {
public:
    FpsamplePluginCreator();

    // ---- IPluginCreatorV3One ----
    IPluginV3* createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
                             TensorRTPhase phase ) noexcept override;
    PluginFieldCollection const* getFieldNames() noexcept override;
    AsciiChar const* getPluginName() const noexcept override;
    AsciiChar const* getPluginVersion() const noexcept override;
    AsciiChar const* getPluginNamespace() const noexcept override;

private:
    PluginFieldCollection m_pfc;                 // view over m_pluginAttributes
    std::vector<PluginField> m_pluginAttributes; // descriptors of the accepted fields
};
} // namespace nvinfer1
#include "FpsamplePlugin.h"
#include <cassert>
#include <iostream>
#include <memory>
#include "../cuda_impl/cuda_impl.h"
#include "../kdtree_fpsample/kdtree_fpsample.h"
namespace {
// Identity strings returned by both the plugin and its creator. They are
// file-local, so no other translation unit can depend on them.
constexpr char const* FPSAMPLE_PLUGIN_VERSION = "1";
constexpr char const* FPSAMPLE_PLUGIN_NAME = "KDTreeFpsample";
constexpr char const* FPSAMPLE_PLUGIN_NAMESPACE = "";
}  // namespace
// Stores `nsample` and publishes it as the single serializable field named
// "nsample".
nvinfer1::FpsamplePlugin::FpsamplePlugin( int32_t nsample )
    : m_nsample( nsample ) {
    m_pluginAttributes.clear();
    // The field points at the member itself, so the collection always reflects
    // the value held by this instance.
    m_pluginAttributes.emplace_back( "nsample", &m_nsample, PluginFieldType::kINT32, 1 );

    // m_pfc is only a view: it borrows the vector's storage.
    m_pfc.fields = m_pluginAttributes.data();
    m_pfc.nbFields = m_pluginAttributes.size();
}
// Returns this object viewed through the capability interface that matches
// `type`. Anything that is neither kBUILD nor kRUNTIME is expected to be kCORE
// (checked by assert in debug builds) and yields the core interface.
IPluginCapability* FpsamplePlugin::getCapabilityInterface( PluginCapabilityType type ) noexcept {
    try {
        switch ( type ) {
        case PluginCapabilityType::kBUILD:
            return static_cast<IPluginV3OneBuild*>( this );
        case PluginCapabilityType::kRUNTIME:
            return static_cast<IPluginV3OneRuntime*>( this );
        default:
            break;
        }
        assert( type == PluginCapabilityType::kCORE );
        return static_cast<IPluginV3OneCore*>( this );
    }
    catch ( ... ) {
        // log error
    }
    return nullptr;
}
// Creates a new plugin carrying the same `nsample` value.
//
// Returns a heap-allocated FpsamplePlugin, or nullptr if construction fails.
// The function is noexcept, so an exception from `new` or from the
// constructor's vector growth must not escape (it would call std::terminate);
// it is converted into a nullptr result instead.
IPluginV3* FpsamplePlugin::clone() noexcept {
    try {
        return new FpsamplePlugin( m_nsample );
    }
    catch ( ... ) {
        // Allocation or construction failed; report it through the return value.
    }
    return nullptr;
}
// Plugin identity accessors: each returns the matching file-local constant.

// Name of the plugin.
AsciiChar const* FpsamplePlugin::getPluginName() const noexcept {
    return FPSAMPLE_PLUGIN_NAME;
}

// Version string of the plugin.
AsciiChar const* FpsamplePlugin::getPluginVersion() const noexcept {
    return FPSAMPLE_PLUGIN_VERSION;
}

// Namespace of the plugin.
AsciiChar const* FpsamplePlugin::getPluginNamespace() const noexcept {
    return FPSAMPLE_PLUGIN_NAMESPACE;
}
// Nothing is cached from the tensor descriptions: all arguments are ignored and
// the call always reports success (0).
int32_t FpsamplePlugin::configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs,
                                         DynamicPluginTensorDesc const* out,
                                         int32_t nbOutputs ) noexcept {
    return 0;
}
// Declares the element type of the single output as INT64.
//
// Debug builds assert that the input/output counts match the plugin's
// declaration and that the input is FLOAT. Always returns 0.
int32_t FpsamplePlugin::getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs,
                                            const DataType* inputTypes,
                                            int32_t nbInputs ) const noexcept {
    // Arity must agree with m_in_n / m_out_n.
    assert( nbInputs == m_in_n );
    assert( nbOutputs == m_out_n );

    // Only a float input is expected.
    assert( inputTypes[ 0 ] == DataType::kFLOAT );

    // The output is written as int64 values.
    outputTypes[ 0 ] = DataType::kINT64;
    return 0;
}
int32_t FpsamplePlugin::getOutputShapes( const DimsExprs* inputs, int32_t nbInputs, const DimsExprs* shapeInputs,
int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
IExprBuilder& exprBuilder ) noexcept {
// 确保输入输出数量符合预期
assert( nbInputs == m_in_n );
assert( nbOutputs == m_out_n );
// 输入张量的形状
const DimsExprs& inputShape = inputs[ 0 ];
// 设置输出张量的形状,假设为 [BATCH, NUM_POINTS]
outputs[ 0 ].nbDims = 2; // 输出维度数为 2
outputs[ 0 ].d[ 0 ] = inputShape.d[ 0 ]; // BATCH
outputs[ 0 ].d[ 1 ] = exprBuilder.constant( m_nsample ); // NUM_POINTS
return 0; // 返回 0 表示成功
}
// Returns the scratch size in bytes that enqueue() needs: one float per point,
// i.e. d[0] * d[1] * sizeof(float) of the first input.
//
// An extent that is still dynamic is reported as a negative value in
// `desc.dims`; multiplying it directly would wrap around to an enormous size_t.
// For such an extent the profile maximum (`max`) is used instead, so the result
// covers every shape in the profile. The product is formed in size_t to avoid
// 32-bit overflow.
//
// Returns 0 when the input description is missing, has fewer than two
// dimensions, or no positive extent can be determined.
size_t nvinfer1::FpsamplePlugin::getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                                                   DynamicPluginTensorDesc const* outputs,
                                                   int32_t nbOutputs ) const noexcept {
    if ( inputs == nullptr || nbInputs < 1 ) {
        return 0;
    }
    DynamicPluginTensorDesc const& tensor = inputs[ 0 ];
    if ( tensor.desc.dims.nbDims < 2 ) {
        return 0;
    }

    // Concrete extent if known, otherwise the upper bound from the profile.
    const auto extentOf = [ &tensor ]( int32_t axis ) -> int64_t {
        const int64_t concrete = tensor.desc.dims.d[ axis ];
        return concrete >= 0 ? concrete : static_cast<int64_t>( tensor.max.d[ axis ] );
    };

    const int64_t batch = extentOf( 0 );
    const int64_t points = extentOf( 1 );
    if ( batch <= 0 || points <= 0 ) {
        return 0;
    }
    return static_cast<size_t>( batch ) * static_cast<size_t>( points ) * sizeof( float );
}
bool FpsamplePlugin::supportsFormatCombination( int32_t pos, const DynamicPluginTensorDesc* inOut, int32_t nbInputs,
int32_t nbOutputs ) noexcept {
// 确保 pos 在合法范围内
assert( pos < ( nbInputs + nbOutputs ) );
// 检查输入
if ( pos == 0 ) {
// 输入必须是 float 类型,且格式为线性
return inOut[ pos ].desc.type == DataType::kFLOAT && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
}
// 检查输出
else if ( pos == 1 ) {
// 输出必须是 int64 类型,且格式为线性
return inOut[ pos ].desc.type == DataType::kINT64 && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
}
return false; // 其他情况不支持
}
// Number of outputs the plugin produces (m_out_n, initialised to 1).
int32_t FpsamplePlugin::getNbOutputs() const noexcept {
    return m_out_n;
}
// No shape-dependent state is kept, so this ignores its arguments and always
// reports success (0).
int32_t FpsamplePlugin::onShapeChange( PluginTensorDesc const* in, int32_t nbInputs,
                                       PluginTensorDesc const* out, int32_t nbOutputs ) noexcept {
    return 0;
}
// Runs furthest point sampling for one execution.
//
// Reads the float input (dims d[0]=B batches, d[1]=N points) and writes int64
// indices to the output (dims d[1]=S samples). `workspace` is used as a
// B*N float scratch buffer handed to the kernel wrapper.
//
// Returns 0 on success, -1 when the workspace is missing, a dimension is
// invalid, or the scratch initialisation fails.
//
// Fixes relative to the previous version:
//  * cudaMemset writes its value per BYTE, so passing 10000 stored the byte
//    0x10 everywhere and each float became ~2.8e-29 instead of a large number.
//    The buffer is now filled with the byte 0x7F, which makes every float
//    0x7F7F7F7F (~3.39e38, finite).
//    NOTE(review): this assumes the scratch buffer holds running minimum
//    distances that must start out "very large" - confirm against the kernel.
//  * The fill is issued with cudaMemsetAsync on `stream`, so it is ordered
//    before the kernel launched on that same stream.
//  * The byte count is computed in size_t to avoid int overflow of B * N.
int32_t FpsamplePlugin::enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
                                 void const* const* inputs, void* const* outputs, void* workspace,
                                 cudaStream_t stream ) noexcept {
    if ( workspace == nullptr ) {
        return -1;
    }

    const float* xyz = static_cast<const float*>( inputs[ 0 ] );
    int64_t* idxs = static_cast<int64_t*>( outputs[ 0 ] );

    // Extents taken from the runtime tensor descriptions.
    const int B = static_cast<int>( inputDesc[ 0 ].dims.d[ 0 ] );
    const int N = static_cast<int>( inputDesc[ 0 ].dims.d[ 1 ] );
    const int S = static_cast<int>( outputDesc[ 0 ].dims.d[ 1 ] );
    if ( B < 0 || N <= 0 || S < 0 ) {
        return -1;
    }
    if ( B == 0 || S == 0 ) {
        return 0;  // empty output: nothing to compute
    }

    // One float of scratch per input point.
    const size_t scratchBytes = static_cast<size_t>( B ) * static_cast<size_t>( N ) * sizeof( float );

    // Byte pattern whose four-fold repetition is a large finite float.
    constexpr int kLargeFloatFillByte = 0x7F;
    if ( cudaMemsetAsync( workspace, kLargeFloatFillByte, scratchBytes, stream ) != cudaSuccess ) {
        return -1;
    }

    // Launch the sampling kernel on the same stream as the fill.
    furthest_point_sampling_kernel_wrapper( B, N, S, xyz, static_cast<float*>( workspace ), idxs, stream );
    return 0;
}
// The resource context is not used; attaching simply yields a fresh copy of
// this plugin obtained from clone().
IPluginV3* FpsamplePlugin::attachToContext( IPluginResourceContext* context ) noexcept {
    return this->clone();
}
// Exposes the field collection built in the constructor (the single "nsample"
// field). The returned pointer refers to a member of this object.
PluginFieldCollection const* FpsamplePlugin::getFieldsToSerialize() noexcept {
    return &m_pfc;
}
/// /
/// FpsamplePluginCreator
// Declares the fields this creator understands: a single INT32 field named
// "nsample". No data pointer is attached because only the description is
// advertised here.
FpsamplePluginCreator::FpsamplePluginCreator() {
    m_pluginAttributes.clear();
    m_pluginAttributes.emplace_back( "nsample", nullptr, PluginFieldType::kINT32, 1 );

    // m_pfc is a view over the vector's storage.
    m_pfc.fields = m_pluginAttributes.data();
    m_pfc.nbFields = m_pluginAttributes.size();
}
// Builds an FpsamplePlugin from the supplied field collection.
//
// The collection is searched for a field named "nsample"; it must be of type
// kINT32 and carry at least one value. `name` and `phase` are not used.
//
// Returns the new plugin, or nullptr when the collection is missing, the field
// is absent or malformed, or construction fails. Validation used to rely on
// assert(), which is compiled out under NDEBUG and would then dereference
// unchecked pointers; it is now done at runtime. Exceptions are contained
// because the function is noexcept.
IPluginV3* FpsamplePluginCreator::createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
                                                TensorRTPhase phase ) noexcept {
    try {
        if ( fc == nullptr || fc->fields == nullptr ) {
            return nullptr;
        }
        for ( int32_t i = 0; i < fc->nbFields; ++i ) {
            PluginField const& field = fc->fields[ i ];
            if ( field.name == nullptr || std::string( field.name ) != "nsample" ) {
                continue;
            }
            if ( field.type != PluginFieldType::kINT32 || field.data == nullptr || field.length < 1 ) {
                return nullptr;
            }
            return new FpsamplePlugin( *static_cast<int32_t const*>( field.data ) );
        }
    }
    catch ( ... ) {
        // Allocation or construction failed; fall through to the nullptr result.
    }
    return nullptr;
}
// Field descriptions accepted by createPlugin (built in the constructor).
PluginFieldCollection const* FpsamplePluginCreator::getFieldNames() noexcept {
    return &m_pfc;
}

// Identity accessors: each returns the same file-local constant that the
// plugin itself reports.

// Name of the plugin this creator produces.
AsciiChar const* FpsamplePluginCreator::getPluginName() const noexcept {
    return FPSAMPLE_PLUGIN_NAME;
}

// Version string of the plugin this creator produces.
AsciiChar const* FpsamplePluginCreator::getPluginVersion() const noexcept {
    return FPSAMPLE_PLUGIN_VERSION;
}

// Namespace of the plugin this creator produces.
AsciiChar const* FpsamplePluginCreator::getPluginNamespace() const noexcept {
    return FPSAMPLE_PLUGIN_NAMESPACE;
}
### 3,注册插件
插件的注册也简单:
// Register both creators under the empty plugin namespace.
// NOTE(review): registerCreator receives a reference, so the creator objects
// presumably have to stay alive for as long as the registry may use them -
// keep these unique_ptrs in a scope that outlives engine loading; confirm
// against the TensorRT documentation.
auto pluginCreator1 = std::make_unique<FpsamplePluginCreator>();
getPluginRegistry()->registerCreator( *pluginCreator1, "" );

auto pluginCreator2 = std::make_unique<BallQueryPluginCreator>();
getPluginRegistry()->registerCreator( *pluginCreator2, "" );
需要注意的是,插件注册需要在加载 engine 文件之前完成,也就是在 tensorrt 初始化之前;
### 4,最后
插件注册的问题,一般,engine 文件都是从onnx文件转换过来的,onnx文件又是从torch转换而来;
比如这里的这个插件,是最远点采样的;
在从torch模型转换为onnx时,如果fps采样在模型内部使用,会导致导出失败或者导出特别慢,所以在torch导出模型时就要注册自定义算子,也就是fps采样的算子,像这样:
class KDTreeFpsample(torch.autograd.Function):
    """Autograd function wrapping ``kdtree_fpsample`` so it can be exported to ONNX.

    ``symbolic`` emits a single custom node ``WX::KDTreeFpsample`` during
    export; ``forward`` performs the actual sampling when run eagerly.
    """

    @staticmethod
    def symbolic(g, xyz: torch.Tensor, nsample: int):
        # Emit one custom op with `xyz` as its input and `nsample` passed via
        # the `nsample_i` keyword.
        return g.op('WX::KDTreeFpsample', xyz, nsample_i=nsample)

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, nsample: int):
        # `ctx` is unused: nothing is saved for a backward pass.
        return kdtree_fpsample(xyz, nsample)
当然这只是我这里使用的方式,别的我也不知道,这里的kdtree_fpsample是封装的C++函数;
这也可以是普通的函数,毕竟这里只是为了导出为onnx模型。
使用方式就是,在模型的需要使用的地方,forward函数里:
def sample_and_group_batch(npoint, radius, nsample, xyz, points):
    """Sample ``npoint`` centroids from ``xyz`` and gather a neighbourhood around each.

    Args:
        npoint: number of centroids to sample.
        radius: search radius forwarded to ``BallQueryBatch``.
        nsample: neighbours per centroid forwarded to ``BallQueryBatch``.
        xyz: rank-3 tensor, unpacked as (B, N, C).
        points: per-point data gathered with the neighbour indices.

    Returns:
        (new_xyz, new_points): the sampled centroids, and the centred neighbour
        coordinates concatenated with the gathered ``points`` on the last axis.
    """
    batch, _, channels = xyz.shape

    # Choose the centroids and look up their coordinates.
    centroid_idx = KDTreeFpsample.apply(xyz, npoint)
    new_xyz = index_points(xyz, centroid_idx)

    # The ball query packs coordinates (first 3 columns) and the index
    # (column 3) into one tensor.
    packed = BallQueryBatch.apply(xyz, new_xyz, radius, nsample)
    target_device = xyz.device
    neighbour_xyz = packed[:, :, :, :3].to(target_device)          # [B, npoint, nsample, 3]
    neighbour_idx = packed[:, :, :, 3].long().to(target_device)    # [B, npoint, nsample]

    # Express neighbour coordinates relative to their centroid.
    centred_xyz = neighbour_xyz - new_xyz.view(batch, npoint, 1, channels)
    neighbour_points = index_points(points, neighbour_idx)

    # [B, npoint, nsample, C+D]
    new_points = torch.cat([centred_xyz, neighbour_points], dim=-1)
    return new_xyz, new_points
......
# forward 函数里
class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: groups points, with a stack of 1x1 conv + batch-norm layers.

    The constructor builds one ``Conv2d``/``BatchNorm2d`` pair per entry of ``mlp``.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        # Chain the layers: each conv consumes the previous layer's width.
        width_in = in_channel
        for width_out in mlp:
            self.mlp_convs.append(nn.Conv2d(width_in, width_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(width_out))
            width_in = width_out

        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Move the channel axis last before grouping.
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # The original grouping paths are kept for reference but disabled:
        # if self.group_all:
        #     new_xyz, new_points = sample_and_group_all(xyz, points)  # not executed here
        # else:
        #     new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        new_xyz, new_points = sample_and_group_batch(
            self.npoint, self.radius, self.nsample, xyz, points
        )
        # NOTE(review): the excerpt stops here; the remainder of forward
        # (applying the conv/bn stack and returning) is not shown.
### 5,最最后
这里我遇到了很多坑,包括torch自定义算子,从torch导出onnx,onnx自定义算子,onnx导出engine,tensorrt自定义插件,等等.....
不过这些坑我都在摸索中解决了,当然现在也差不多忘记了,可以交流讨论,如果我还记得的话。