TensorRT 10.x 自定义插件

## 前言

新版本的 TensorRT 自定义插件有了新的实现方式。在之前的版本中,网上资料大多说要从头编译整个 TensorRT,过程非常麻烦,而且官方文档也很难读懂。

更新:

写了一个简单的示例,仅供参考:
make-a/tensorrt-plugin-example

## 自定义插件

### 1,先下载新版本的tensorrt 10.x库,这有很多教程

TensorRT 10.x Download | NVIDIA Developer

然后就是配置环境,这也有很多教程,我这里是 vs2022 加 TensorRT-10.6.0.26

### 2,新建插件类,实现插件接口

新版的插件需要实现新的插件接口:
 

// Excerpt: a V3 plugin derives from IPluginV3 plus the core, build and runtime
// capability interfaces. The complete declaration appears further below.
#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>

namespace nvinfer1 {

class FpsamplePlugin : public nvinfer1::IPluginV3,
                       public nvinfer1::IPluginV3OneCore,
                       public nvinfer1::IPluginV3OneBuild,
                       public nvinfer1::IPluginV3OneRuntime {

public:
    FpsamplePlugin( int32_t nsample );
    // ... remaining members omitted in this excerpt

我这里就直接贴出这个插件的完整代码了,可以参考;每个函数的具体含义,可以查看文档;这里仅仅是参考。
 

#pragma once
#include <string>
#include <vector>

#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>

using namespace nvinfer1;

namespace nvinfer1 {

/// TensorRT IPluginV3 plugin that runs furthest-point sampling on a point cloud.
///
/// One object provides the core, build and runtime capabilities (see
/// getCapabilityInterface). Input 0 is float32 and output 0 holds int64 indices.
class FpsamplePlugin : public nvinfer1::IPluginV3,
                       public nvinfer1::IPluginV3OneCore,
                       public nvinfer1::IPluginV3OneBuild,
                       public nvinfer1::IPluginV3OneRuntime {

public:
    /// @param nsample Number of points to sample; becomes output dimension 1.
    explicit FpsamplePlugin( int32_t nsample );

    // m_pfc.fields points into m_pluginAttributes, whose single field points at m_nsample.
    // A member-wise copy would keep pointers into the *source* object, so copying (and, as a
    // consequence, implicit moving) is disabled. Use clone(), which rebuilds both from m_nsample.
    FpsamplePlugin( FpsamplePlugin const& )            = delete;
    FpsamplePlugin& operator=( FpsamplePlugin const& ) = delete;

    // Inherited via IPluginV3
    IPluginCapability* getCapabilityInterface( PluginCapabilityType type ) noexcept override;
    IPluginV3*         clone() noexcept override;

    // Inherited via IPluginV3OneCore
    AsciiChar const* getPluginName() const noexcept override;
    AsciiChar const* getPluginVersion() const noexcept override;
    AsciiChar const* getPluginNamespace() const noexcept override;

    // Inherited via IPluginV3OneBuild
    int32_t configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
                             int32_t nbOutputs ) noexcept override;
    int32_t getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs, const DataType* inputTypes,
                                int32_t nbInputs ) const noexcept override;
    int32_t getOutputShapes( DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
                             int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
                             IExprBuilder& exprBuilder ) noexcept override;
    size_t  getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                              DynamicPluginTensorDesc const* outputs, int32_t nbOutputs ) const noexcept override;
    bool    supportsFormatCombination( int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs,
                                       int32_t nbOutputs ) noexcept override;
    int32_t getNbOutputs() const noexcept override;

    // Inherited via IPluginV3OneRuntime
    int32_t onShapeChange( PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out,
                           int32_t nbOutputs ) noexcept override;
    int32_t enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
                     void* const* outputs, void* workspace, cudaStream_t stream ) noexcept override;
    IPluginV3*                   attachToContext( IPluginResourceContext* context ) noexcept override;
    PluginFieldCollection const* getFieldsToSerialize() noexcept override;

private:
    int32_t                  m_nsample;
    PluginFieldCollection    m_pfc;               // serialization view over m_pluginAttributes
    std::vector<PluginField> m_pluginAttributes;  // single "nsample" field pointing at m_nsample

    int32_t m_in_n  = 1;  // number of plugin inputs
    int32_t m_out_n = 1;  // number of plugin outputs
};

class FpsamplePluginCreator : public nvinfer1::IPluginCreatorV3One {
public:
    FpsamplePluginCreator();

    // 通过 IPluginCreatorV3One 继承
    IPluginV3* createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
                             TensorRTPhase phase ) noexcept override;

    PluginFieldCollection const* getFieldNames() noexcept override;
    AsciiChar const*             getPluginName() const noexcept override;
    AsciiChar const*             getPluginVersion() const noexcept override;
    AsciiChar const*             getPluginNamespace() const noexcept override;

private:
    PluginFieldCollection    m_pfc;
    std::vector<PluginField> m_pluginAttributes;
};
}  // namespace nvinfer1
#include "FpsamplePlugin.h"

#include <cassert>
#include <iostream>
#include <memory>

#include "../cuda_impl/cuda_impl.h"
#include "../kdtree_fpsample/kdtree_fpsample.h"

namespace {
// Identity strings shared by FpsamplePlugin and FpsamplePluginCreator. The plugin and its
// creator must report identical name/version/namespace so TensorRT can pair them.
constexpr char FPSAMPLE_PLUGIN_VERSION[]   = "1";
constexpr char FPSAMPLE_PLUGIN_NAME[]      = "KDTreeFpsample";
constexpr char FPSAMPLE_PLUGIN_NAMESPACE[] = "";
}  // namespace

nvinfer1::FpsamplePlugin::FpsamplePlugin( int32_t nsample )
    : m_nsample( nsample ) {
    // Publish "nsample" for serialization. The field's data points at our own member,
    // so getFieldsToSerialize() always reflects this instance's configuration.
    m_pluginAttributes.clear();
    m_pluginAttributes.push_back( PluginField{ "nsample", &m_nsample, PluginFieldType::kINT32, 1 } );

    m_pfc.fields   = m_pluginAttributes.data();
    m_pfc.nbFields = static_cast<int32_t>( m_pluginAttributes.size() );
}

IPluginCapability* FpsamplePlugin::getCapabilityInterface( PluginCapabilityType type ) noexcept {
    // This object implements every capability, so each request returns `this`,
    // cast to the matching interface subobject.
    try {
        switch ( type ) {
        case PluginCapabilityType::kBUILD:
            return static_cast<IPluginV3OneBuild*>( this );
        case PluginCapabilityType::kRUNTIME:
            return static_cast<IPluginV3OneRuntime*>( this );
        default:
            assert( type == PluginCapabilityType::kCORE );
            return static_cast<IPluginV3OneCore*>( this );
        }
    }
    catch ( ... ) {
        // Guard for the noexcept contract; nothing above is expected to throw.
    }
    return nullptr;
}

IPluginV3* FpsamplePlugin::clone() noexcept {
    // Builds a fresh plugin with the same nsample; the constructor rebuilds the
    // self-referencing serialization fields for the new object.
    // Allocation failure must not escape this noexcept function (that would call
    // std::terminate); TensorRT treats a nullptr result as a failed clone.
    try {
        return new FpsamplePlugin( m_nsample );
    }
    catch ( ... ) {
        return nullptr;
    }
}

// Identity accessors: they return the same constants as FpsamplePluginCreator,
// which is how the plugin is matched to its creator.
AsciiChar const* FpsamplePlugin::getPluginName() const noexcept { return FPSAMPLE_PLUGIN_NAME; }

AsciiChar const* FpsamplePlugin::getPluginVersion() const noexcept { return FPSAMPLE_PLUGIN_VERSION; }

AsciiChar const* FpsamplePlugin::getPluginNamespace() const noexcept { return FPSAMPLE_PLUGIN_NAMESPACE; }

int32_t FpsamplePlugin::configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs,
                                         DynamicPluginTensorDesc const* out, int32_t nbOutputs ) noexcept {
    return 0;
}

int32_t FpsamplePlugin::getOutputDataTypes( DataType* outputTypes, int32_t nbOutputs, const DataType* inputTypes,
                                            int32_t nbInputs ) const noexcept {
    // Debug-only sanity checks: one input, one output, and the input is float32.
    assert( nbInputs == m_in_n );
    assert( nbOutputs == m_out_n );
    assert( inputTypes[ 0 ] == DataType::kFLOAT );

    // The single output holds sample indices, so it is always int64.
    DataType& indexType = outputTypes[ 0 ];
    indexType           = DataType::kINT64;
    return 0;
}

int32_t FpsamplePlugin::getOutputShapes( const DimsExprs* inputs, int32_t nbInputs, const DimsExprs* shapeInputs,
                                         int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
                                         IExprBuilder& exprBuilder ) noexcept {
    // 确保输入输出数量符合预期
    assert( nbInputs == m_in_n );
    assert( nbOutputs == m_out_n );

    // 输入张量的形状
    const DimsExprs& inputShape = inputs[ 0 ];

    // 设置输出张量的形状,假设为 [BATCH, NUM_POINTS]
    outputs[ 0 ].nbDims = 2;                                  // 输出维度数为 2
    outputs[ 0 ].d[ 0 ] = inputShape.d[ 0 ];                  // BATCH
    outputs[ 0 ].d[ 1 ] = exprBuilder.constant( m_nsample );  // NUM_POINTS

    return 0;  // 返回 0 表示成功
}

size_t nvinfer1::FpsamplePlugin::getWorkspaceSize( DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
                                                   DynamicPluginTensorDesc const* outputs,
                                                   int32_t                        nbOutputs ) const noexcept {
    // enqueue() uses one float of scratch per input point, i.e. B * N floats.
    // At build time desc.dims holds -1 for dynamic axes, so size the buffer for the
    // optimization profile's maximum shape instead (for static shapes max == desc.dims).
    auto const&   maxDims = inputs[ 0 ].max;
    int64_t const B       = maxDims.d[ 0 ];
    int64_t const N       = maxDims.d[ 1 ];
    if ( B <= 0 || N <= 0 ) {
        return 0;
    }
    // Multiply in size_t so large batches or point clouds cannot overflow 32-bit ints.
    return static_cast<size_t>( B ) * static_cast<size_t>( N ) * sizeof( float );
}

bool FpsamplePlugin::supportsFormatCombination( int32_t pos, const DynamicPluginTensorDesc* inOut, int32_t nbInputs,
                                                int32_t nbOutputs ) noexcept {
    // 确保 pos 在合法范围内
    assert( pos < ( nbInputs + nbOutputs ) );

    // 检查输入
    if ( pos == 0 ) {
        // 输入必须是 float 类型,且格式为线性
        return inOut[ pos ].desc.type == DataType::kFLOAT && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
    }
    // 检查输出
    else if ( pos == 1 ) {
        // 输出必须是 int64 类型,且格式为线性
        return inOut[ pos ].desc.type == DataType::kINT64 && inOut[ pos ].desc.format == TensorFormat::kLINEAR;
    }

    return false;  // 其他情况不支持
}

// Number of output tensors (the single index tensor).
int32_t FpsamplePlugin::getNbOutputs() const noexcept { return m_out_n; }

int32_t FpsamplePlugin::onShapeChange( PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out,
                                       int32_t nbOutputs ) noexcept {
    return 0;
}

int32_t FpsamplePlugin::enqueue( PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
                                 void const* const* inputs, void* const* outputs, void* workspace,
                                 cudaStream_t stream ) noexcept {
    // Input 0: float point cloud read as [B, N, ...]; output 0: int64 indices, [B, S].
    const float* xyz  = static_cast<const float*>( inputs[ 0 ] );
    int64_t*     idxs = static_cast<int64_t*>( outputs[ 0 ] );

    int const B = static_cast<int>( inputDesc[ 0 ].dims.d[ 0 ] );
    int const N = static_cast<int>( inputDesc[ 0 ].dims.d[ 1 ] );
    int const S = static_cast<int>( outputDesc[ 0 ].dims.d[ 1 ] );

    // An empty batch or sample count produces an empty output: nothing to launch.
    if ( B == 0 || S == 0 ) {
        return 0;
    }
    if ( workspace == nullptr ) {
        return -1;
    }

    // Scratch: one float per input point, presumably the per-point minimum distance used
    // by the FPS kernel (TODO confirm against furthest_point_sampling_kernel_wrapper).
    // cudaMemset* writes a single *byte* value, so a value such as 10000 would be truncated
    // to 0x10 and yield floats of ~2.8e-29. Filling every byte with 0x7F instead gives
    // 0x7F7F7F7F (~3.39e38), a large positive finite sentinel.
    // Using the async variant on `stream` keeps the fill ordered before the kernel launch.
    constexpr int kLargeFloatByte = 0x7F;
    size_t const  tempBytes       = static_cast<size_t>( B ) * static_cast<size_t>( N ) * sizeof( float );
    if ( cudaMemsetAsync( workspace, kLargeFloatByte, tempBytes, stream ) != cudaSuccess ) {
        return -1;
    }

    furthest_point_sampling_kernel_wrapper( B, N, S, xyz, static_cast<float*>( workspace ), idxs, stream );

    // Report launch failures instead of claiming success unconditionally.
    return cudaGetLastError() == cudaSuccess ? 0 : -1;
}

IPluginV3* FpsamplePlugin::attachToContext( IPluginResourceContext* /*context*/ ) noexcept {
    // No per-context resources are requested, so each execution context simply
    // receives an independent copy of this plugin.
    IPluginV3* const perContextCopy = clone();
    return perContextCopy;
}

PluginFieldCollection const* FpsamplePlugin::getFieldsToSerialize() noexcept {
    // m_pfc is populated in the constructor with the single "nsample" field.
    PluginFieldCollection const* const serialized = &m_pfc;
    return serialized;
}

/// /
/// FpsamplePluginCreator

FpsamplePluginCreator::FpsamplePluginCreator() {
    // Advertise the attributes createPlugin() understands. Data stays null because this
    // collection only describes the schema (name, type, element count).
    m_pluginAttributes.clear();
    m_pluginAttributes.push_back( PluginField{ "nsample", nullptr, PluginFieldType::kINT32, 1 } );

    m_pfc.fields   = m_pluginAttributes.data();
    m_pfc.nbFields = static_cast<int32_t>( m_pluginAttributes.size() );
}

IPluginV3* FpsamplePluginCreator::createPlugin( AsciiChar const* name, PluginFieldCollection const* fc,
                                                TensorRTPhase phase ) noexcept {
    // Builds an FpsamplePlugin from the "nsample" attribute. Malformed input yields nullptr
    // (reported by TensorRT as a creation failure) instead of a null dereference, which
    // is all the debug-only asserts used to guard against.
    try {
        if ( fc == nullptr || fc->fields == nullptr || fc->nbFields <= 0 ) {
            return nullptr;
        }

        // Look the attribute up by name. Fall back to a lone field for backward
        // compatibility with callers that relied on position 0.
        PluginField const* nsampleField = nullptr;
        for ( int32_t i = 0; i < fc->nbFields; ++i ) {
            PluginField const& field = fc->fields[ i ];
            if ( field.name != nullptr && std::string( field.name ) == "nsample" ) {
                nsampleField = &field;
                break;
            }
        }
        if ( nsampleField == nullptr && fc->nbFields == 1 ) {
            nsampleField = &fc->fields[ 0 ];
        }
        if ( nsampleField == nullptr || nsampleField->type != PluginFieldType::kINT32 ||
             nsampleField->data == nullptr || nsampleField->length < 1 ) {
            return nullptr;
        }

        int32_t const nsample = *static_cast<int32_t const*>( nsampleField->data );
        if ( nsample <= 0 ) {
            return nullptr;  // becomes an output dimension, so it must be positive
        }
        return new FpsamplePlugin( nsample );
    }
    catch ( ... ) {
        // e.g. std::bad_alloc; exceptions must not escape this noexcept function.
    }
    return nullptr;
}

PluginFieldCollection const* FpsamplePluginCreator::getFieldNames() noexcept {
    // Schema built in the constructor: one int32 field named "nsample".
    PluginFieldCollection const* const schema = &m_pfc;
    return schema;
}

// Identity accessors: they return the same constants as FpsamplePlugin, so the
// registry can pair the creator with the plugins it produces.
AsciiChar const* FpsamplePluginCreator::getPluginName() const noexcept { return FPSAMPLE_PLUGIN_NAME; }

AsciiChar const* FpsamplePluginCreator::getPluginVersion() const noexcept { return FPSAMPLE_PLUGIN_VERSION; }

AsciiChar const* FpsamplePluginCreator::getPluginNamespace() const noexcept { return FPSAMPLE_PLUGIN_NAMESPACE; }

### 3,注册插件

插件的注册也简单:

    auto pluginCreator1 = std::make_unique<FpsamplePluginCreator>();
    getPluginRegistry()->registerCreator( *pluginCreator1.get(), "" );

    auto pluginCreator2 = std::make_unique<BallQueryPluginCreator>();
    getPluginRegistry()->registerCreator( *pluginCreator2.get(), "" );

需要注意的是,插件注册需要在加载 engine 文件之前完成,也就是在 TensorRT 初始化之前;另外,注册表只保存 creator 的引用,所以 creator 对象必须在整个使用期间保持存活。

### 4,最后

关于插件的来源:一般来说,engine 文件是从 ONNX 文件转换而来的,而 ONNX 文件又是从 torch 模型导出的;
比如这里的这个插件,实现的是最远点采样(FPS);
在把 torch 模型导出为 ONNX 时,如果 FPS 采样在模型内部使用,会导致导出失败或者导出特别慢,所以在 torch 导出模型时就要注册自定义算子,也就是 FPS 采样的算子,像这样:

class KDTreeFpsample(torch.autograd.Function):
    """Furthest-point sampling exposed as a single custom op for ONNX export.

    ``symbolic`` emits one ``WX::KDTreeFpsample`` node carrying the integer attribute
    ``nsample`` (the name the TensorRT plugin creator reads). ``forward`` runs the
    ``kdtree_fpsample`` extension function when the model executes eagerly.
    """

    @staticmethod
    def symbolic(g, xyz: torch.Tensor, nsample: int):
        # The ``_i`` suffix makes the exporter store ``nsample`` as an INT attribute.
        node_attrs = {'nsample_i': nsample}
        return g.op('WX::KDTreeFpsample', xyz, **node_attrs)

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, nsample: int):
        return kdtree_fpsample(xyz, nsample)

当然这只是我这里使用的方式,别的我也不知道,这里的kdtree_fpsample是封装的C++函数;
这也可以是普通的函数,毕竟这里只是为了导出为onnx模型。
使用方式就是,在模型的需要使用的地方,forward函数里:

def sample_and_group_batch(npoint, radius, nsample, xyz, points):
    """Sample ``npoint`` centroids with FPS and group ``nsample`` neighbours around each.

    Input:
        npoint: number of centroids to sample (S)
        radius: ball-query radius
        nsample: neighbours gathered per centroid
        xyz: point positions, [B, N, C]
        points: per-point features, [B, N, D], or None
    Return:
        new_xyz: sampled centroids, [B, S, C]
        new_points: centroid-relative neighbour positions concatenated with the grouped
            features, [B, S, nsample, C+D]; just the positions ([B, S, nsample, C])
            when ``points`` is None
    """
    B, N, C = xyz.shape
    S = npoint

    fps_idx = KDTreeFpsample.apply(xyz, npoint)  # presumably [B, S] indices into N
    new_xyz = index_points(xyz, fps_idx)

    # BallQueryBatch packs neighbour coordinates and their indices into one tensor:
    # channels [:3] are xyz, channel 3 is the neighbour index stored as a float.
    neighbors = BallQueryBatch.apply(xyz, new_xyz, radius, nsample)
    grouped_xyz = neighbors[:, :, :, :3]  # [B, npoint, nsample, 3]
    grouped_idx = neighbors[:, :, :, 3].long()  # [B, npoint, nsample]
    # The custom op may return tensors on a different device; bring them back to xyz's.
    grouped_xyz = grouped_xyz.to(xyz.device)
    grouped_idx = grouped_idx.to(xyz.device)

    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is None:
        # The caller allows feature-less input; the relative coordinates are the output.
        return new_xyz, grouped_xyz_norm

    grouped_points = index_points(points, grouped_idx)
    # [B, npoint, nsample, C+D]
    new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)

    return new_xyz, new_points

......

# forward 函数里

class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: samples centroids, groups neighbours around them, and
    holds per-stage 1x1 Conv2d + BatchNorm2d layers (``mlp_convs`` / ``mlp_bns``),
    presumably applied to the grouped features in the omitted part of ``forward``.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        # One (Conv2d, BatchNorm2d) pair per entry in ``mlp``; channel counts chain together.
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Switch to channels-last ([B, N, C] / [B, N, D]) for sampling and grouping.
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        # if self.group_all:
        #     new_xyz, new_points = sample_and_group_all(xyz, points)  # not taken in this model
        # else:
        #     new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # The export-friendly path goes through the custom FPS / ball-query ops.
        new_xyz, new_points = sample_and_group_batch(self.npoint, self.radius, self.nsample, xyz, points)
        # NOTE(review): excerpt only. The rest of forward (applying mlp_convs/mlp_bns and
        # returning new_xyz / new_points_concat) is omitted here, so as shown the method
        # returns None.


 

### 5,最最后

这里我遇到了很多坑,包括torch自定义算子,从torch导出onnx,onnx自定义算子,onnx导出engine,tensorrt自定义插件,等等.....

不过这些坑我都在摸索中解决了,当然现在也差不多忘记了,可以交流讨论,如果我还记得的话。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值