TensorRT的plugin实现

1. 总体概述(以实现add一个整数为例)

编写custom插件需要写两个类,分别如下:

  • AddPlugin:继承IPluginV2IOExt,插件类,用于编写插件需要实现的功能

  • AddPluginCreator:继承IPluginCreator,插件Factory类,用于创建插件

class AddPlugin: public nvinfer1::IPluginV2IOExt

class AddPluginCreator : public nvinfer1::IPluginCreator

后续工作:

  1. 将插件添加到TensorRT-OSS

  2. 将插件添加到onnx-tensorrt

2. AddPlugin类的实现

参考链接:TensorRT: nvinfer1::IPluginV2IOExt Class Reference (nvidia.com)

构造函数和析构函数

构造函数一般需要实现两个:

  • 第一个用于在创建plugin的过程,此时PluginCreator的createPlugin成员函数会调用

AddPlugin(nvinfer1::Weights valueToAdd)
  • 第二个用于Plugin类的clone成员函数,PluginCreator的deserializePlugin成员函数

AddPlugin(const void *buffer, size_t length)

同时,需要禁用默认构造函数

AddPlugin() = delete
析构函数用于释放该plugin之前开辟的显存空间
~AddPlugin() {}

四个重要的成员函数

  • getOutputDimensions

TensorRT支持Dynamic-Shape时,batch这一维度必须是explicit的,也就是说,TensorRT处理的维度从以往的三维【3,-1,-1】变成了【1,3,-1,-1】。 根据输入的维度推导出该plugin输出的维度。

nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* pInputDim, int nInputDim) override 
{ 
    /* get the dimension of an output tensor.
     * index: the index of the output tensor. 
    * pInputDim: the input tensors. 
    * nInputDim: the number of input tensors. */ 
    return pInputDim[0]; 
}
  • supportsFormatCombination

判断pos索引的输入/输出数据是否符合指定的format格式和type数据类型。

bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override 
    {
        /*
        return true if plugin supports the format and datatype for the input/output indexed by pos.
        * inputs are numbers [0, nbInputs-1]
        * outputs are numbers [nbInputs, nbInputs+nbOutputs-1]
        * 
        */
    	switch(pos) {
    	case 0:
    		printf("inOut[0].type = %d, format[0]=%d\n", (int)inOut[0].type, (int)inOut[0].format);
    		return 
    			((inOut[0].type == nvinfer1::DataType::kFLOAT || inOut[0].type == nvinfer1::DataType::kHALF) && inOut[0].format == nvinfer1::TensorFormat::kLINEAR)
    			|| (inOut[0].type == nvinfer1::DataType::kINT8 && inOut[0].format == nvinfer1::TensorFormat::kCHW4);
    	case 1:
    		printf("inOut[1].type = %d, format[1]=%d\n", (int)inOut[1].type, (int)inOut[1].format);
    		return inOut[0].format == inOut[1].format && inOut[0].type == inOut[1].type;
    	}
    	return false;
    }
  • configurePlugin

判断输入和输出类型,数量是否正确。

virtual void configurePlugin(const nvinfer1::PluginTensorDesc* in, int nbInput, const nvinfer1::PluginTensorDesc* out, int nbOutput) override 
    {
        /*
        fields that a plugin might see for an input or output.
        * scale is only valid when datatype is DataType::kINT8.
        * TensorRT will set the value to -1.0f if it is invalid.
        */
    	m.dataType = in[0].type;
    	m.inputDim = in[0].dims;
    	m.scale = in[0].scale;
    	printf("configurePlugin type=%d, m.scale=%f\n", (int)out[0].type, m.scale);
    }
  • enqueue

该plugin功能实现的接口,功能实现的cuda或cpu代码放入此。

int enqueue(int nBatch, const void * const *inputs, void **outputs, void* workspace, cudaStream_t stream) override;

四个注册到pluginFactory的信息

set/getPluginNamespace: 为plugin设置namespace名字,如果不设置则默认是"",需要注意的是同一个namespace下的plugin的名字相同会冲突。 getPluginType:获取plugin的name getPluginVersion: 获取plugin的版本

void setPluginNamespace(const char* szNamespace) override {}
const char* getPluginNamespace() const override {return "";}
const char* getPluginType() const override {return "AddPlugin";}
const char* getPluginVersion() const override {return "0";}

获取plugin的信息

getNbOutputs:获取plugin输出的个数,这个根据plugin的功能事先决定

int getNbOutputs() const override 
{
        return 1;
}

getOutputDataType:获取plugin输出数据的类型是否满足要求

nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override 
{
    	return inputTypes[0] == nvinfer1::DataType::kFLOAT ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kINT8;
   }

getWorkspaceSize:获取plugin运行占用的显存大小,需要确定这个op需要多大的显存空间去运行,在实际运行的时候就可以直接使用TensorRT开辟好的空间而不是自己去申请显存空间。

size_t getWorkspaceSize(int nMaxBatchSize) const override {return 0;}
  • initialize

初始化函数,在这个plugin准备开始运行之前执行。

int initialize() override {return 0;}
  • clone

将plugin对象克隆一份给TensorRT的builder,network和engine。

 nvinfer1::IPluginV2IOExt* clone() const override 
{
        return new AddPlugin(&m, sizeof(m));
}
  • serialize

将plugin中的参数序列化写入buffer文件中

virtual void serialize(void *buffer) const override {
        memcpy(buffer, &m, sizeof(m));
    }

getSerializationSize:得到plugin中参数的内存大小,返回序列化时需要写多少字节到buffer中。(第二个析构函数)

virtual size_t getSerializationSize() const override 
{
        return sizeof(m);
    }

两个plugin结束处理函数

terminate:继承父类,无操作 destroy:用于销毁plugin的对象

void terminate() override {}
void destroy() override { delete this; }

四个不重要的函数

bool canBroadcastInputAcrossBatch(int inputIndex) const override {return false;}
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const {return false;}
void attachToContext(cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, nvinfer1::IGpuAllocator* /*allocator*/) {}
void detachFromContext() {}

3. AddPluginCreator类实现

参考链接:TensorRT: nvinfer1::IPluginCreator Class Reference (nvidia.com)

构造函数和析构函数

构造函数用于初始化需要传入plugin中的权重和参数。

MyCustomPluginCreator::MyCustomPluginCreator()
{
    mPluginAttributes.emplace_back(PluginField("in_channel", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("weight", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1));
    
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

四个plugin的相关信息获取/设定

const char* getPluginName() const override {return "AddPlugin";}
const char* getPluginVersion() const override {return "0";}

void setPluginNamespace(const char* szNamespace) override {}
const char* getPluginNamespace() const override {return "";}
  • createPlugin

通过PluginFieldCollection将plugin需要的权重和参数,并调用插件类的第一个构造函数创建plugin。

  nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override {
        std::cout << __FUNCTION__ << std::endl;
        float valueToAdd = 0;
        for (int i = 0; i < fc->nbFields; i++) {
            if (!strcmp(fc->fields[i].name, "valueToAdd")) {
                valueToAdd = *(float *)fc->fields[i].data;
            }
        }
        return new AddPlugin({nvinfer1::DataType::kFLOAT, &valueToAdd, 1});
    }
  • deserializePlugin

从保存的engine文件中反序列化数据

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override {
        return new AddPlugin(serialData, serialLength);
    }
  • getFieldNames

用于一系列PluginFiled对象,传入createPlugin中,创建plugin对象

const nvinfer1::PluginFieldCollection* getFieldNames() override {
        std::cout << __FUNCTION__ << std::endl;
        return nullptr;
    }

4. plugin的注册

4.1 本地plugin注册

当我们只想在某个项目中使用该plugin,可以通过在插件的实现cpp或cu文件中添加如下代码完成plugin的注册。

REGISTER_TENSORRT_PLUGIN(AddPluginCreator);

4.2 TensorRT-OSS注册

编写TensorRT-OSS的plugin时,插件类有时继承的类不同,而插件工厂类则继承的是BaseCreator。

4.3 onnx-tensorrt注册

onnx-tensorrt中的builtin_op_importers.cpp文件中,我们采用DEFINE_BUILTIN_OP_IMPORTER去注册op,然后通过parse解析onnx模型,根据注册好的op去一个个解析并构建模型。

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值