tensorrt learn
plugin官方
官方下边有三个类IPluginCreator,IPluginRegistry和IPluginFactory
IPluginCreator
底下有PluginFieldType和PluginFieldCollection派生类。
PluginFieldType的成员变量有name、data、type、size。
PluginFieldCollection包括append()、extend()、insert()、pop()函数,其中的操作对象都是PluginFieldType类型
IPluginCreator
主要的成员变量有tensorrt_version、name、plugin_version、field_names、plugin_namespace
主要的成员函数有create_plugin(const char* name,const nvinfer1::PluginFieldCollection *fc)
name – The name of the plugin.
field_collection – The PluginFieldCollection for this plugin.
Returns
IPluginV2 or None on failure.
这里的create_plugin()返回的是序列化的构造函数
virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLenth) override;
Returns
A new IPluginV2
这里的deserializePlugin返回的是反序列化的构造函数
creator要实现的函数大概如下:
class PReLUPluginCreator : public nvinfer1::IPluginCreator {
public:
PReLUPluginCreator();
// ------------------inherit from IPluginCreator-------------------
// return the plugin type + plugin namesapce
virtual const char* getPluginName() const override;
// return the plugin version
virtual const char* getPluginVersion() const override;
// return a list of fields that needs to be passed to createPlugin
virtual const nvinfer1::PluginFieldCollection* getFieldNames() override;
// return nullptr in case of error
virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;
// Called during deserialization of plugin layer. Return a plugin object.
virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLenth) override;
// Set the namespace of the plugin creator based on the plugin library it belongs to. This can be set while registering the plugin creator
virtual void setPluginNamespace(const char* pluginNamespace) override {
}
// Return the namespace of the plugin creator object.
virtual const char* getPluginNamespace() const override;
private:
nvinfer1::PluginFieldCollection mFC;
std::vector<nvinfer1::PluginField> mPluginAttributes;
};
不要忘记注册REGISTER_TENSORRT_PLUGIN(PReLUPluginCreator);
IPluginRegistry
貌似别人实现的接口文件中,没有重写此方法,先保留官方描述
IPluginFactory
factory工厂,大概就是读入模型文件中的层名判断其是否在trt自带的层和plugin层中
IPluginV2* PluginFactory::createPlugin(const char *layerName, const Weights* weights, int nbWeights, const char* libNamespace)
IPluginV2* PluginFactory::createPlugin(const char* layerName, const void* serialData, size_t serialLength) override;
之前v1是有两个实现都是调用自定义类的构造函数,V2中好像参数只有serialized数据buff,但别人的实现好像是weights
主要要实现以下方法
class PluginFactory : public nvcaffeparser1::IPluginFactoryV2 {
public:
PluginFactory(TrtPluginParams params);
virtual ~PluginFactory() {
}
// ------------------inherit from IPluginFactoryV2--------------------
// determines if a layer configuration is provided by an IPluginV2
virtual bool isPluginV2(const char* layerName) override;
// create a plugin
virtual IPluginV2* createPlugin(const char* layerName, const Weights* weights, int nbWeights, const char* libNamespace="") override;
private:
//这里是你的plugin中的参数
// yolo-det layer params
int mYoloClassNum;
int mYolo3NetSize;
// upsample layer params
float mUpsampleScale;
};
自定义plugin
plugin类要实现的操作如下
class PReLUPlugin : public nvinfer1::IPluginV2
{
public:
// @参数: weights 和 nbWeight这两个参数是PluginFactory::createPlugin的参数,可以参见PluginFactory.cpp.如果你的自定义层没有权重,那么这两个参数你不要也可以,这个函数主要就是用来将权重和自定义层的其他参数读取到内部变量里面*/
// 个人理解这个构造函数的作用是将weights等参数序列化
PReLUPlugin(const nvinfer1::Weights* weights, int nbWeight);
// 这个就是从序列化数据里面恢复plugin的相关数据,另一个函数serialize,将类的数据写入到序列化数据里面.在IPluginCreator::deserializePlugin里面会调用到这个函数,注意写的顺序跟读的顺序必须是一样的.
// 个人理解这个构造函数的作用是将序列化后的参数反序列化成weights等
PReLUPlugin(const void* data, size_t length);
// 返回在序列化你的自定义插件的时候,需要占用到多少空间,其实就是你的权重和一些必要的成员变量的空间
virtual size_t getSerializationSize() const override;
// 序列化你的自定义插件到buffer,需要保证write的顺序和read的顺序是一样的
virtual void serialize(void* buffer) const override;
PReLUPlugin() = delete;
~PReLUPlugin();
// 返回输出tensor的数量, 比如说prelu,输出个数跟relu一样是1,这个取决于你的自定义层.
virtual int getNbOutputs() const override;
// @描述 返回输出tensor的维度,很多时候都取决于输入维度.对于prelu来说,输出维度等于输入维度.
// @参数 index 输出tensor的index
// @参数 inputs 输出tensors的纬度.注意有可能有多个输入
// @参数 nbInputDims 输出tensors的个数.
virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override;
// 查询对应的datatype和format是否支持, 这个取决于你的自定义层实现是否支持.
virtual bool supportsFormat(const nvinfer1::DataType type, nvinfer1::PluginFormat format) const override;
virtual void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs