本文源于学习TensorRT文档《TensorRT-Developer-Guide》第4章“EXTENDING TENSORRT WITH CUSTOM LAYERS”的理解。
通过C++API添加自定义层
自定义层添加是通过扩展IPluginV2Ext和IPluginCreator类来实现:
- IPluginV2Ext:IPluginV2的升级版,实现自定义插件的基类,包含版本化和对其它格式和单精度的处理;
- IPluginCreator:自定义层的创建类,可以通过它获取插件的名称、版本信息、参数等,也提供网络创建阶段创建插件的方法,并在推理阶段反序列化它。
对定义好的插件可以通过REGISTER_TENSORRT_PLUGIN(pluginCreator)
进行静态注册,并在使用时通过getPluginRegistry()
查询并使用。官方已经实现的插件有:
- RPROI_TRT
- Normalize_TRT
- PriorBox_TRT
- GridAnchor_TRT
- NMS_TRT
- LReLU_TRT
- Reorg_TRT
- Region_TRT
- Clip_TRT
// 通过getPluginRegistry获取所有TensorRT插件,creator即IPluginCreator对象
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// 填充该层参数信息,pluginData需要先通过PluginField分配堆上空间
PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);
// 使用层名和插件参数创建新的插件对象,创建在堆上,需要主动释放
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// 在网络上添加一层,并将该层和插件绑定,layer即IPluginV2Layer对象
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
// TODO:创建最新的网络,并序列化引擎
// 销毁插件对象
pluginObj->destroy()
// TODO:释放TensorRT资源,network、engine、builder
// TODO:释放显存空间,如原网络参数信息pluginData
TensorRT的引擎会在序列化时内部存储IPluginV2插件的属性信息,并在反序列化时通过插件注册表进行查找,并通过IPluginV2::destroy()接口内部销毁。
过去的版本中,用户必须通过nvinfer1::IPluginFactory类在反序列化时创建插件,现在的TensorRT版本可以使用addPluginV2即可。例如:
// 使用Caffe解释器解析网络并添加插件
// 如果使用IPluginExt创建插件,需要搭配nvinfer1::IPluginFactory 和 nvinfer1::IPluginFactory
class FooPlugin : public IPluginExt
{
// TODO:创建插件实现方法
};
class MyPluginFactory :
public nvinfer1::IPluginFactory,
public nvcaffeparser1::IPluginFactoryExt
{
// TODO:创建插件的工厂方法
};
// 如果使用IPluginV2创建并注册插件,则不再需要实现nvinfer1::IPluginFactory,
// 但需要通过nvcaffeparser1::IPluginFactoryV2 和 IPluginCreator来完成注册
class FooPlugin : public IPluginV2
{
// TODO:创建插件实现方法
};
class FooPluginFactory : public nvcaffeparser1::IPluginFactoryV2
{
virtual nvinfer1::IPluginV2* createPlugin(...)
{
// TODO:创建并返回插件对象,如FooPlugin
}
bool isPlugin(const char* name)
{
// TODO:通过网络层的名字检验是否使用该插件
}
}
class FooPluginCreator : public IPluginCreator
{
// TODO:实现所有的插件创建
};
REGISTER_TENSORRT_PLUGIN(FooPluginCreator);
具体的插件创建实例可以查看:
- samplePlugin:自定义Caffe网络插件方法;
- sampleFasterRCNN:通过TensorRT注册Caffe网络插件;
- sampleUffSSD:对UFF(针对TensorFlow)添加插件。
使用自定义插件
该部分内容基本与创建时介绍的情况雷同,需要注意的是对于Caffe解释器,可以通过setPluginFactoryV2 和 IPluginFactoryV2使用自定义插件,那么在反序列化时创建的插件会按照 IPluginExt::destroy()中定义的内容内部销毁而无需手动调用,用户只需要销毁创建创建过程中的插件对象。
API描述
IPluginV2的API
1、获取插件输出数据结构,检验是否可以和相邻层对接:
- getNbOutputs:验证输出张量数目;
- getOutputDimensions:验证输入维度,获取输出维度;
- supportsFormat:设置插件支持的数据类型,如何种处理精度;
- getOutputDataType:插件输出数据的类型(NCHW、NC/2HW2 、NHWC8等,见PluginFormatType)。
2、获取插件除了输入输出外,需要占用多大的空间存储数据,在builder中调用并预分配:
- getWorkspaceSize
3、插件在创建阶段会多次配置、初始化、执行、中止,而运行时只会多次执行,配置、初始化、中止只执行一次,initialize申请的内存需要在terminate时被释放,其它的内存需要在destroy释放,所需要的插件为:
- configurePlugin:配置输入输出属性(数量、维度、类型、广播、格式选择、最大BatchSize),插件会选择最合适的算法和数据结构;
- initialize:在插件配置和推理引擎创建之后使用,根据设置的数据结构配置并准备执行;
- enqueue:插件实际处理过程,需输入运行BatchSize、输入指针、输出指针、缓存空间指针、CUDA流;
- terminate:在引擎的上下文被释放时释放插件的所有资源;
- clone:在需要一个独立插件时(新的builder、network、engine被创建)使用;
- destroy:在builder、network、engine销毁时调用,释放对应的插件资源;
- set/getPluginNamespace:设置或获取插件的命名空间,默认为""(空)。
4、通过IPluginV2Ext可以实现输入输出的广播性质,需要实现:
- canBroadcastInputAcrossBatch:判断输入张量是否可以在批中进行广播,能广播则返回true,TensorRT不会复制输入并使用同一输入副本;不能广播返回false,TensorRT会复制输入张量;
- isOutputBroadcastAcrossBatch:指定索引的输出是否被广播。
IPluginCreator的API
IPluginCreator中用来从插件库中查找并创建插件的方法:
- getPluginName:获取插件的名字,并和getPluginType配合使用;
- getPluginVersion:返回插件版本,TensorRT内部插件默认为1;
- getFieldNames:返回PluginFieldCollection结构数据,包含添加插件的参数名和类型;
- createPlugin:通过给定的PluginFieldCollection结构参数创建插件,需填充实际所需参数;
- deserializePlugin:在TensorRT引擎根据插件名和版本信息内部调用,返回用于推理的插件对象;
- set/getPluginNamespace:creator所在的插件库命名空间,默认为""(空)。
从5.x.x迁移到5.1.x
5.x.x版本中没有getOutputDataType、isOutputBroadcastAcrossBatch、canBroadcastInputAcrossBatch,configurePlugin是针对configureWithFormat的升级。在迁移到5.1.x时需要实现这些新特性。
virtual nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const = 0;
virtual bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const = 0;
virtual bool canBroadcastInputAcrossBatch(int inputIndex) const = 0;
virtual void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) = 0;