tensorrt 官方学习

plugin官方

官方下边有三个类IPluginCreator,IPluginRegistry和IPluginFactory

IPluginCreator

底下有PluginFieldType和PluginFieldCollection派生类。
PluginFieldType的成员变量有name、data、type、size。
PluginFieldCollection包括append()、extend()、insert()、pop()函数,其中的操作对象都是PluginFieldType类型
IPluginCreator

主要的成员变量有tensorrt_version、name、plugin_version、field_names、plugin_namespace
主要的成员函数有create_plugin(const char* name,const nvinfer1::PluginFieldCollection *fc) 
name – The name of the plugin.
field_collection – The PluginFieldCollection for this plugin.
Returns
IPluginV2 or None on failure.

这里的create_plugin()返回的是序列化的构造函数

 virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLenth) override;
 Returns
A new IPluginV2

这里的deserializePlugin返回的是反序列化的构造函数
creator要实现的函数大概如下:

class PReLUPluginCreator : public nvinfer1::IPluginCreator {
   
public:
    PReLUPluginCreator();

    // ------------------inherit from IPluginCreator-------------------
    // return the plugin type + plugin namesapce
    virtual const char* getPluginName() const override;
    // return the plugin version
    virtual const char* getPluginVersion() const override;

    // return a list of fields that needs to be passed to createPlugin
    virtual const nvinfer1::PluginFieldCollection* getFieldNames() override;

    // return nullptr in case of error
    virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;

    // Called during deserialization of plugin layer. Return a plugin object.
    
    virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLenth) override;

    // Set the namespace of the plugin creator based on the plugin library it belongs to. This can be set while registering the plugin creator
    virtual void setPluginNamespace(const char* pluginNamespace) override {
   }

    // Return the namespace of the plugin creator object.
    virtual const char* getPluginNamespace() const override;

private:
    nvinfer1::PluginFieldCollection mFC;
    std::vector<nvinfer1::PluginField> mPluginAttributes;
};

不要忘记注册REGISTER_TENSORRT_PLUGIN(PReLUPluginCreator);

IPluginRegistry

貌似别人实现的接口文件中,没有重写此方法,先保留官方描述

IPluginFactory

factory工厂,大概就是读入模型文件中的层名判断其是否在trt自带的层和plugin层中

IPluginV2* PluginFactory::createPlugin(const char *layerName, const Weights* weights, int nbWeights, const char* libNamespace)
IPluginV2* PluginFactory::createPlugin(const char* layerName, const void* serialData, size_t serialLength) override;

之前v1是有两个实现都是调用自定义类的构造函数,V2中好像参数只有serialized数据buff,但别人的实现好像是weights
主要要实现以下方法

class PluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
   
public:
    PluginFactory(TrtPluginParams params);

    virtual ~PluginFactory() {
   }
    // ------------------inherit from IPluginFactoryV2--------------------
    // determines if a layer configuration is provided by an IPluginV2
    virtual bool isPluginV2(const char* layerName) override;

    // create a plugin
    virtual IPluginV2* createPlugin(const char* layerName, const Weights* weights, int nbWeights, const char* libNamespace="") override;

private:
    //这里是你的plugin中的参数
    // yolo-det layer params
    int mYoloClassNum;
    int mYolo3NetSize;
    // upsample layer params
    float mUpsampleScale;
};

自定义plugin

plugin类要实现的操作如下

class PReLUPlugin : public nvinfer1::IPluginV2
{
   
public:

    //  @参数: weights 和 nbWeight这两个参数是PluginFactory::createPlugin的参数,可以参见PluginFactory.cpp.如果你的自定义层没有权重,那么这两个参数你不要也可以,这个函数主要就是用来将权重和自定义层的其他参数读取到内部变量里面*/
    //  个人理解这个构造函数的作用是将weights等参数序列化
    PReLUPlugin(const nvinfer1::Weights* weights, int nbWeight);
    // 这个就是从序列化数据里面恢复plugin的相关数据,另一个函数serialize,将类的数据写入到序列化数据里面.在IPluginCreator::deserializePlugin里面会调用到这个函数,注意写的顺序跟读的顺序必须是一样的.
    // 个人理解这个构造函数的作用是将序列化后的参数反序列化成weights等
    PReLUPlugin(const void* data, size_t length);
	// 返回在序列化你的自定义插件的时候,需要占用到多少空间,其实就是你的权重和一些必要的成员变量的空间
    virtual size_t getSerializationSize() const override;
    //  序列化你的自定义插件到buffer,需要保证write的顺序和read的顺序是一样的  
    virtual void serialize(void* buffer) const override;
  
    PReLUPlugin() = delete;

    ~PReLUPlugin();
    // 返回输出tensor的数量, 比如说prelu,输出个数跟relu一样是1,这个取决于你的自定义层.
    virtual int getNbOutputs() const override;  
    //   @描述 返回输出tensor的维度,很多时候都取决于输入维度.对于prelu来说,输出维度等于输入维度.
    //   @参数 index 输出tensor的index
    //   @参数 inputs 输出tensors的纬度.注意有可能有多个输入
    //   @参数 nbInputDims 输出tensors的个数.
    virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override;


    //  查询对应的datatype和format是否支持, 这个取决于你的自定义层实现是否支持.
    
    virtual bool supportsFormat(const nvinfer1::DataType type, nvinfer1::PluginFormat format) const override;

    virtual void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小涵涵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值