TensorRT的自定义算子Plugin的实现

本文详细介绍了如何在TensorRT 7.0中实现自定义算子,以解决Caffe模型中Upsample层的兼容性问题。通过继承IPluginV2IOExt、IPluginCreator及IPluginFactoryV2等接口,文章提供了实现自定义算子的具体步骤和关键代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这篇文章主要介绍了如何使用TensorRT实现自定义算子。

Note:

  1. 我使用的是TensorRT7.0,自定义算子使用的IPluginV2IOExt实现的。
  2. 模型框架是caffe,所以以下实现都只适用于caffe模型的解析,但理论上解析tf和onnx的改动不大。
  3. 实现细节不方便全部贴出,但是基本实现过程和结构都在下面了,照着写写没啥问题了。

其实自定义算子写多了发现其实还挺好写的,格式都差不多,主要区别是enqueue的前向计算逻辑可能写起来复杂些。
整个实现过程基本上是:

  1. 继承nvinfer1::IPluginV2IOExt,并实现相应的虚函数。
  2. 继承nvinfer1::IPluginCreator并实现相应的虚函数。
  3. 继承nvcaffeparser1::IPluginFactoryV2并实现相应的虚函数。
  4. 在解析网络之前调用REGISTER_TENSORRT_PLUGIN注册UpsampleCreator和调用parser->setPluginFactoryV2()以使用自定义层类型。

以Upsample为例,TensorRT不支持Caffe的Upsample层,所以这里实现了一个自定义层类型,即plugin。需要实现:

  1. Upsample类,继承自nvinfer1::IPluginV2IOExt。
  2. UpsampleCreator类,继承自nvinfer1::IPluginCreator。
  3. CaffePluginFactory类,继承自nvcaffeparser1::IPluginFactoryV2。

需要实现的函数详见如下代码段。

Upsample类的实现:

class Upsample : public nvinfer1::IPluginV2IOExt {
public:
    // 直接解析网络时候需要用到
    Upsample();
    // 反序列化时候需要用到
    Upsample(const void *data, size_t length);
    ~Upsample();
    
    // 直接return输出节点数,
    int getNbOutputs() override;
    
    // return输出的维度信息,如:return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
    Dims getOutputDimensions(int index, const Dims *inputs, int num_input_dims) override;
    
    // pos索引到的input/output的数据格式(format)和数据类型(datatype)如果都支持则返回true
    bool supportsFormatCombination(int pos, const PluginTensorDesc* in_out, int num_inputs, int num_outputs) const override;
    
    // 这个函数可以获取到数据类型和输入的维度信息,如果有需要用到的可以在这里将相关信息取出来
    configurePlugin(const PluginTensorDesc* in, int num_inputs, const PluginTensorDesc* out, int num_outputs) override;

    // 在这里返回正确的序列化数据的长度,如我要序列化数据类型和数据维度:return sizeof(data_type) + sizeof(chw);
    size_t getSerializationSize() const override;
    
    // 序列化函数,在这里把反序列化时需要用到的参数或数据序列化
    void serialize(void *buffer) const override;
    
    // 设置工作空间,不需要直接 return 0;
    size_t getWorkspaceSize(int max_batch_size) const override;
    
    // 前向计算的核心函数,计算逻辑在这里实现,可以使用cublas实现或者自己写cuda核函数实现
    int enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
    
    // 调用enqueue的时候需要用到的资源先在这里Initialize,这个函数是在engine创建之后enqueue调用之前调用的,不需要Initialize则直接 return 0;
    int initialize() override;
    
    // 释放Initialize申请的资源,在enqueue调用之后且engine销毁之后调用
    void terminate() override;
    
    // 返回输出的数据类型,如何输入相同,可以直接 return input_types[0];
    nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* input_types, int num_inputs) const override;
    
    // 返回自定义类型,如这里是:return Upsample
    const char* getPluginType() const override;
    
    // 返回plugin version,没啥说的
    const char* getPluginVersion() const override;
      
    // 销毁对象
    void destroy() override {
        delete this;
    }
    
    // 在这里new一个该自定义类型并返回
    nvinfer1::IPluginV2Ext* clone() const override;
    
    // 设置命名空间,用来在网络中查找和创建plugin
    void setPluginNamespace(const char* lib_namespace) override;
    // 返回plugin对象的命名空间
    const char* getPluginNamespace() const override;
    bool isOutputBroadcastAcrossBatch(int output_index, const bool* input_is_broadcasted, int num_inputs) const override;
    bool canBroadcastInputAcrossBatch(int input_index) const override;
}

下面是对应的Creator类的实现

class UpsampleCreator : public nvinfer1::IPluginCreator {
public:
    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;
    // 创建自定义层pluin的对象并返回
    nvinfer1::IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override;
    // 创建自定义层pluin的对象并返回,反序列化用到
    nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serial_data, size_t serial_length) override;
    void setPluginNamespace(const char* lib_namespace) override;
    const char* getPluginNamespace() const override;
}

下面是对应的plugin factory类的实现

class CaffePluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
    // 在这里判断一个层是否为自定义层类型
    bool isPluginV2(const char* name) override;
    // 在这里创建自定义层类型的对象并返回
    nvinfer1::IPluginV2* createPlugin(const char* layer_name, const nvinfer1::Weights* weights, int num_weights, const char* libNamespace="") override;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值