前言
本文为学习Nvidia官方视频教程的学习笔记,分享TensorRT手写Plugin的详细步骤,文中代码参考Nvidia提供的cookbook中05-Plugin/API的代码,强烈建议您先观看Nvidia的教程视频
第三节,然后学习本文档。
一、plugin_creator_list
首先介绍一下Plugin_creator_list,这个列表中保存了TensorRT提供的官方Plugin,我们可以通过add_plugin_v2
直接调用这里面已有的Plugin,当遇到TensorRT不支持的算子时,我们则需要手写Plugin并注册到plugin_creator_list中,最后通过同样的方式调用即可。
plugin_creator_list
的结构图如下, plugin_creator_list 中有多个TensorRT提供的Plugin,每个Plugin都封装在一个creator
中,每个creator中保存了相应的版本信息、命名空间、Plugin版本、Plugin名称,另外在 pluginField中保存了Plugin的数据类型、所需参数,我们可以通过配置pluginField来配置Plugin。
二、手写Plugin具体步骤
1.手写Plugin核函数
代码如下(对应cookbook中的trt-samples-for-hackathon-cn\cookbook\05-Plugin\API\ AddScalarPlugin.cu):
#include "AddScalarPlugin.h"
// kernel for GPU
// CUDA kernel: element-wise addition of a compile-time-unknown scalar to the
// input tensor. One thread handles one element; threads whose global index
// falls past the end of the tensor do nothing.
__global__ void addScalarKernel(const float *input, float *output, const float scalar, const int nElement)
{
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nElement)
{
output[tid] = input[tid] + scalar;
}
}
namespace nvinfer1
{
// class AddScalarPlugin
// Build-time constructor: records the scalar to be added into the
// serializable parameter struct m_.
AddScalarPlugin::AddScalarPlugin(const std::string &name, float scalar):
name_(name)
{
WHERE_AM_I();
m_.scalar = scalar;
}
// Deserialization constructor: restores the parameter struct m_ from a byte
// buffer previously produced by serialize().
// NOTE(review): `length` is not validated against sizeof(m_); it is assumed
// to equal getSerializationSize() as per the TensorRT contract — confirm.
AddScalarPlugin::AddScalarPlugin(const std::string &name, const void *buffer, size_t length):
name_(name)
{
WHERE_AM_I();
memcpy(&m_, buffer, sizeof(m_));
}
// Nothing to release: all members are POD or self-managing (std::string).
AddScalarPlugin::~AddScalarPlugin()
{
WHERE_AM_I();
}
// Deep-copy the plugin by round-tripping the parameter struct m_ through the
// deserialization constructor; the namespace must be propagated by hand.
IPluginV2DynamicExt *AddScalarPlugin::clone() const noexcept
{
WHERE_AM_I();
auto p = new AddScalarPlugin(name_, &m_, sizeof(m_));
p->setPluginNamespace(namespace_.c_str());
return p;
}
// This plugin produces exactly one output tensor.
int32_t AddScalarPlugin::getNbOutputs() const noexcept
{
WHERE_AM_I();
return 1;
}
// The output inherits the data type of input 0.
DataType AddScalarPlugin::getOutputDataType(int32_t index, DataType const *inputTypes, int32_t nbInputs) const noexcept
{
WHERE_AM_I();
return inputTypes[0];
}
// Element-wise operation: the output shape equals the input shape.
DimsExprs AddScalarPlugin::getOutputDimensions(int32_t outputIndex, const DimsExprs *inputs, int32_t nbInputs, IExprBuilder &exprBuilder) noexcept
{
WHERE_AM_I();
return inputs[0];
}
// Decide whether the type/format combination at slot `pos` (inputs first,
// then outputs) is supported by this plugin.
bool AddScalarPlugin::supportsFormatCombination(int32_t pos, const PluginTensorDesc *inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
WHERE_AM_I();
bool res = false;
if (pos == 0)
{
// Input 0: only linear-layout FP32 is accepted.
res = inOut[0].type == DataType::kFLOAT && inOut[0].format == TensorFormat::kLINEAR;
}
else if (pos == 1)
{
// Output 0: must mirror the input's type and format exactly.
res = inOut[1].type == inOut[0].type && inOut[1].format == inOut[0].format;
}
// Any other pos should never occur; res stays false in that case.
#ifdef DEBUG
std::cout << "\tpos=" << pos << ",res=" << res << "->[";
for (int i = 0; i < nbInputs + nbOutputs; ++i)
{
std::cout << formatToString(inOut[i].format) << ",";
}
std::cout << "],[";
for (int i = 0; i < nbInputs + nbOutputs; ++i)
{
std::cout << dataTypeToString(inOut[i].type) << ",";
}
std::cout << "]" << std::endl;
#endif
return res;
}
// No shape- or type-dependent precomputation is needed for this plugin.
void AddScalarPlugin::configurePlugin(const DynamicPluginTensorDesc *in, int32_t nbInputs, const DynamicPluginTensorDesc *out, int32_t nbOutputs) noexcept
{
WHERE_AM_I();
return;
}
// The kernel works in place on its input/output buffers and needs no extra
// GPU scratch memory.
size_t AddScalarPlugin::getWorkspaceSize(const PluginTensorDesc *inputs, int32_t nbInputs, const PluginTensorDesc *outputs, int32_t nbOutputs) const noexcept
{
WHERE_AM_I();
return 0;
}
// Launch the add-scalar kernel on the given stream. The runtime element
// count is the product of all dimensions of input 0 (resolved at enqueue
// time, so dynamic shapes are handled naturally).
int32_t AddScalarPlugin::enqueue(const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept
{
WHERE_AM_I();
int totalElements = 1;
for (int d = 0; d < inputDesc[0].dims.nbDims; ++d)
{
totalElements *= inputDesc[0].dims.d[d];
}
const int blockSize = 256;
dim3 grid(CEIL_DIVIDE(totalElements, blockSize), 1, 1), block(blockSize, 1, 1);
addScalarKernel<<<grid, block, 0, stream>>>(reinterpret_cast<const float *>(inputs[0]), reinterpret_cast<float *>(outputs[0]), m_.scalar, totalElements);
return 0;
}
// Called by TensorRT when the engine/context that owns this plugin is
// destroyed; the object frees itself.
void AddScalarPlugin::destroy() noexcept
{
WHERE_AM_I();
delete this;
return;
}
// No lazy resources to set up; always succeeds.
int32_t AddScalarPlugin::initialize() noexcept
{
WHERE_AM_I();
return 0;
}
// Counterpart of initialize(); nothing to tear down.
void AddScalarPlugin::terminate() noexcept
{
WHERE_AM_I();
return;
}
// The serialized blob is just the raw bytes of the parameter struct m_.
size_t AddScalarPlugin::getSerializationSize() const noexcept
{
WHERE_AM_I();
return sizeof(m_);
}
// Write the parameter struct into `buffer`; per the TensorRT contract the
// caller provides at least getSerializationSize() bytes.
void AddScalarPlugin::serialize(void *buffer) const noexcept
{
WHERE_AM_I();
memcpy(buffer, &m_, sizeof(m_));
return;
}
// Store the namespace used to disambiguate plugins from different vendors.
void AddScalarPlugin::setPluginNamespace(const char *pluginNamespace) noexcept
{
WHERE_AM_I();
namespace_ = pluginNamespace;
return;
}
// Return the namespace set via setPluginNamespace() (empty by default).
const char *AddScalarPlugin::getPluginNamespace() const noexcept
{
WHERE_AM_I();
return namespace_.c_str();
}
// Plugin type name ("AddScalar"); the constant is defined in the header.
const char *AddScalarPlugin::getPluginType() const noexcept
{
WHERE_AM_I();
return PLUGIN_NAME;
}
// Plugin version string; the constant is defined in the header.
const char *AddScalarPlugin::getPluginVersion() const noexcept
{
WHERE_AM_I();
return PLUGIN_VERSION;
}
// This plugin uses neither cuDNN nor cuBLAS nor a custom allocator, so the
// provided handles are ignored.
void AddScalarPlugin::attachToContext(cudnnContext *contextCudnn, cublasContext *contextCublas, IGpuAllocator *gpuAllocator) noexcept
{
WHERE_AM_I();
return;
}
// Counterpart of attachToContext(); nothing was attached, nothing to detach.
void AddScalarPlugin::detachFromContext() noexcept
{
WHERE_AM_I();
return;
}
// class AddScalarPluginCreator
PluginFieldCollection AddScalarPluginCreator::fc_ {};
std::vector<PluginField> AddScalarPluginCreator::attr_;
// Describe the parameters this plugin accepts: a single FP32 field named
// "scalar". The description lives in the static members fc_/attr_ shared by
// all creator instances.
AddScalarPluginCreator::AddScalarPluginCreator()
{
WHERE_AM_I();
attr_.clear();
attr_.emplace_back(PluginField("scalar", nullptr, PluginFieldType::kFLOAT32, 1));
fc_.nbFields = attr_.size();
fc_.fields = attr_.data();
}
// Nothing to release; the static field descriptions outlive the creator.
AddScalarPluginCreator::~AddScalarPluginCreator()
{
WHERE_AM_I();
}
// Build a new AddScalarPlugin from the user-supplied field collection.
// Unknown fields are ignored; a missing "scalar" field leaves the default 0.
IPluginV2DynamicExt *AddScalarPluginCreator::createPlugin(const char *name, const PluginFieldCollection *fc) noexcept
{
WHERE_AM_I();
float scalar = 0;
std::map<std::string, float *> parameterMap {{"scalar", &scalar}};
// This method is noexcept: dereferencing a null collection or a field with
// a null data pointer would crash the whole build, so guard both cases.
if (fc != nullptr)
{
for (int i = 0; i < fc->nbFields; ++i)
{
if (fc->fields[i].data == nullptr)
{
continue;
}
if (parameterMap.find(fc->fields[i].name) != parameterMap.end())
{
*parameterMap[fc->fields[i].name] = *reinterpret_cast<const float *>(fc->fields[i].data);
}
}
}
AddScalarPlugin *pObj = new AddScalarPlugin(name, scalar);
pObj->setPluginNamespace(namespace_.c_str());
return pObj;
}
// Recreate a plugin instance from the byte blob produced by serialize(),
// e.g. when an engine is loaded from a plan file.
IPluginV2DynamicExt *AddScalarPluginCreator::deserializePlugin(const char *name, const void *serialData, size_t serialLength) noexcept
{
WHERE_AM_I();
AddScalarPlugin *pObj = new AddScalarPlugin(name, serialData, serialLength);
pObj->setPluginNamespace(namespace_.c_str());
return pObj;
}
// Store the namespace that will be stamped onto plugins this creator builds.
void AddScalarPluginCreator::setPluginNamespace(const char *pluginNamespace) noexcept
{
WHERE_AM_I();
namespace_ = pluginNamespace;
return;
}
// Return the creator's namespace (empty by default).
const char *AddScalarPluginCreator::getPluginNamespace() const noexcept
{
WHERE_AM_I();
return namespace_.c_str();
}
// Name under which this creator is listed in the plugin registry.
const char *AddScalarPluginCreator::getPluginName() const noexcept
{
WHERE_AM_I();
return PLUGIN_NAME;
}
// Version string used (together with the name) to select this creator.
const char *AddScalarPluginCreator::getPluginVersion() const noexcept
{
WHERE_AM_I();
return PLUGIN_VERSION;
}
// Expose the static description of accepted fields (set up in the ctor) so
// callers know which parameters createPlugin() expects.
const PluginFieldCollection *AddScalarPluginCreator::getFieldNames() noexcept
{
WHERE_AM_I();
return &fc_;
}
REGISTER_TENSORRT_PLUGIN(AddScalarPluginCreator);
} // namespace nvinfer1
由于篇幅限制,相应的.h头文件请参考cookbook。编写完Plugin的核函数后,我们需要将其编译生成相应的Plugin.so
的库文件,方便后续的Plugin读取和使用。
2.读取Plugin
上一步我们生成了相应的Plugin.so库文件,这一步我们通过调用.so文件,实现对Plugin的注册和配置,最终实现在网络中插入自己手写的Plugin。
Plugin的初始化、注册以及配置的具体流程如下:
创建完成Plugin后,我们还可以进行序列化和反序列化的操作:
pluginString = plugin.serialize()
plugin = creator.deserialize_plugin(creator.name, pluginString) # create a plugin by memory of serialized plugin
最后将Plugin插入到Network中:
pluginLayer = network.add_plugin_v2([inputT0], plugin)
详细的代码如下(对应cookbook中的trt-samples-for-hackathon-cn\cookbook\05-Plugin\API\ main.py):
import ctypes
import os
from glob import glob

import numpy as np
import tensorrt as trt
from cuda import cudart

soFile = "./AddScalarPlugin.so"
np.set_printoptions(precision=3, linewidth=200, suppress=True)
np.random.seed(31193)
cudart.cudaDeviceSynchronize()

def getAddScalarPlugin(scalar):
    # Look up the "AddScalar" creator in the global registry, fill in its one
    # field and build a plugin instance; return None if the creator is absent.
    for c in trt.get_plugin_registry().plugin_creator_list:
        #print(c.name)
        if c.name == "AddScalar":
            parameterList = []
            parameterList.append(trt.PluginField("scalar", np.float32(scalar), trt.PluginFieldType.FLOAT32))
            return c.create_plugin(c.name, trt.PluginFieldCollection(parameterList))
    return None

# os.chdir("/w/gitlab/tensorrt-cookbook/05-Plugin/API/")
# Load default plugin creators
logger = trt.Logger(trt.Logger.ERROR)  # create the logger
trt.init_libnvinfer_plugins(logger, '')  # initialize the built-in plugin library
pluginRegistry = trt.get_plugin_registry()  # get the global plugin registry
print("Count of default plugin creators = %d" % len(pluginRegistry.plugin_creator_list))  # plugins already registered

# Attributions of Plugin Registry
print("pluginRegistry.error_recorder =", pluginRegistry.error_recorder)  # ErrorRecorder can be set into EngineInspector, usage of ErrorRecorder refer to 02-API/ErrorRecorder
pluginRegistry.parent_search_enabled = True  # whether search plugin creators in parent directory, default value is True

# Load local plugin creators
for soFile in glob("./*.so"):  # load the plugin libraries we built ourselves
    if True:  # common method
        ctypes.cdll.LoadLibrary(soFile)
    else:  # use TensorRT API, but there are some problems, do not use this temporarily
        handle = pluginRegistry.load_library(soFile)
        #pluginRegistry.deregister_library(handle)  # deregister the library
print("Count of total plugin creators = %d" % len(pluginRegistry.plugin_creator_list))  # one more plugin creator "AddScalar" added
#pluginRegistry.deregister_library(?)  # deregister the library

# print information of all plugin creators
print("TensorRTVersion Namespace PluginVersion Name")
for creator in pluginRegistry.plugin_creator_list:
    print("%4s %s %s %s" % (creator.tensorrt_version, ("\"\"" if creator.plugin_namespace == "" else creator.plugin_namespace), creator.plugin_version, creator.name))

for creator in pluginRegistry.plugin_creator_list:
    if creator.name == "AddScalar" and creator.plugin_version == "1":  # check name and version during selecting plugin
        # print the necessary parameters for creating the plugin
        for i, pluginField in enumerate(creator.field_names):
            print("%2d->%s, %s, %s, %s" % (i, pluginField.name, pluginField.type, pluginField.size, pluginField.data))
        # i=0; pluginField.name=scalar; pluginField.type=PluginFieldType.FLOAT32; pluginField.size=1; pluginField.data=None
        # We can register and deregister a plugin creator in the Plugin Registry, but it is not required
        #pluginRegistry.deregister_creator(creator)  # deregister the plugin creator
        #pluginRegistry.register_creator(creator)  # register the plugin creator again

        # feed the PluginCreator with parameters
        pluginFieldCollection = trt.PluginFieldCollection()  # configure the plugin fields
        pluginField = trt.PluginField("scalar", np.float32(1.0), trt.PluginFieldType.FLOAT32)
        # tensorrt.PluginFieldType: FLOAT16, FLOAT32, FLOAT64, INT8, INT16, INT32, CHAR, DIMS, UNKNOWN
        print(pluginField.name, pluginField.type, pluginField.size, pluginField.data)
        pluginFieldCollection.append(pluginField)  # use like a list
        #pluginFieldCollection.insert(1,pluginField)
        #pluginFieldCollection.extend([pluginField])
        #pluginFieldCollection.clear()
        #pluginFieldCollection.pop(1)
        plugin = creator.create_plugin(creator.name, pluginFieldCollection)  # create a plugin by parameters
        plugin.__class__ = trt.IPluginV2Ext  # change class of plugin from IPluginV2 to IPluginV2Ext, we still do not have IPluginV2Dynamic class

        # methods not work in python API
        #plugin.supports_format(trt.float32, None)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.attach_to_context(None, None)
        #plugin.detach_from_context()
        #plugin.configure_with_format([[2]], [[2]], trt.float32, None, 1)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.configure_plugin([[2]],[[2]],[trt.float32],[trt.float32],[False],[False], None, 1)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.execute_async(1, [None], [None], None, 0)  # address of input / output / workspace memory
        #plugin.initialize()
        #plugin.terminate()
        #plugin.destroy()

        # methods work (but useless) in python API
        print("plugin.plugin_type =", plugin.plugin_type)
        print("plugin.plugin_namespace =", plugin.plugin_namespace)
        print("plugin.plugin_version =", plugin.plugin_version)
        print("plugin.num_outputs =", plugin.num_outputs)
        print("plugin.serialization_size =", plugin.serialization_size)
        print("plugin.tensorrt_version =", plugin.tensorrt_version)
        print("plugin.clone() =", plugin.clone())
        print("plugin.get_output_data_type(0, [trt.float32]) =", plugin.get_output_data_type(0, [trt.float32]))
        print("plugin.get_output_shape(0, [trt.Dims([2])])) =", plugin.get_output_shape(0, [trt.Dims([2])]))  # output is always ((0))?
        print("plugin.get_workspace_size(1) =", plugin.get_workspace_size(1))  # output is always 0?

        pluginString = plugin.serialize()
        plugin = creator.deserialize_plugin(creator.name, pluginString)  # create a plugin by memory of serialized plugin

builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
config = builder.create_builder_config()

inputT0 = network.add_input("inputT0", trt.float32, [-1])
profile.set_shape(inputT0.name, [1], [2], [4])
config.add_optimization_profile(profile)

pluginLayer = network.add_plugin_v2([inputT0], plugin)
print(pluginLayer.plugin)  # other members and methods refer to 02-API/Layer

print("Finish")

# methods not work
#trt.get_builder_plugin_registry(None)  # nvinfer1::EngineCapability
输出结果:
python3 ./main.py
Count of default plugin creators = 67
pluginRegistry.error_recorder = None
Count of total plugin creators = 68
TensorRTVersion Namespace PluginVersion Name
8601 "" 1 CaskDeconvShaderWeightsTransformerPlugin
8601 "" 1 CaskConvShaderWeightsTransformerPlugin
8601 "" 1 RNNTEncoderPlugin
8601 "" 1 SmallTileGEMM_TRT
8601 "" 1 DLRM_BOTTOM_MLP_TRT
8601 "" 1 CustomQKVToContextPluginDynamic
8601 "" 2 CustomQKVToContextPluginDynamic
8601 "" 3 CustomQKVToContextPluginDynamic
8601 "" 1 CustomSkipLayerNormPluginDynamic
8601 "" 2 CustomSkipLayerNormPluginDynamic
8601 "" 3 CustomSkipLayerNormPluginDynamic
8601 "" 4 CustomSkipLayerNormPluginDynamic
8601 "" 1 SingleStepLSTMPlugin
8601 "" 1 RnRes2FullFusion_TRT
8601 "" 1 RnRes2Br2bBr2c_TRT
8601 "" 2 RnRes2Br2bBr2c_TRT
8601 "" 1 RnRes2Br1Br2c_TRT
8601 "" 2 RnRes2Br1Br2c_TRT
8601 "" 1 GroupNormalizationPlugin
8601 "" 1 CustomGeluPluginDynamic
8601 "" 1 CustomFCPluginDynamic
8601 "" 2 CustomEmbLayerNormPluginDynamic
8601 "" 3 CustomEmbLayerNormPluginDynamic
8601 "" 1 CustomEmbLayerNormPluginDynamic
8601 "" 1 DisentangledAttention_TRT
8601 "" 1 BatchedNMSDynamic_TRT
8601 "" 1 BatchedNMS_TRT
8601 "" 1 BatchTilePlugin_TRT
8601 "" 1 Clip_TRT
8601 "" 1 CoordConvAC
8601 "" 1 CropAndResizeDynamic
8601 "" 1 CropAndResize
8601 "" 1 DecodeBbox3DPlugin
8601 "" 1 DetectionLayer_TRT
8601 "" 1 EfficientNMS_Explicit_TF_TRT
8601 "" 1 EfficientNMS_Implicit_TF_TRT
8601 "" 1 EfficientNMS_ONNX_TRT
8601 "" 1 EfficientNMS_TRT
8601 "" 1 FlattenConcat_TRT
8601 "" 1 GenerateDetection_TRT
8601 "" 1 GridAnchor_TRT
8601 "" 1 GridAnchorRect_TRT
8601 "" 1 InstanceNormalization_TRT
8601 "" 2 InstanceNormalization_TRT
8601 "" 1 LReLU_TRT
8601 "" 1 ModulatedDeformConv2d
8601 "" 1 MultilevelCropAndResize_TRT
8601 "" 1 MultilevelProposeROI_TRT
8601 "" 1 MultiscaleDeformableAttnPlugin_TRT
8601 "" 1 NMSDynamic_TRT
8601 "" 1 NMS_TRT
8601 "" 1 Normalize_TRT
8601 "" 1 PillarScatterPlugin
8601 "" 1 PriorBox_TRT
8601 "" 1 ProposalDynamic
8601 "" 1 ProposalLayer_TRT
8601 "" 1 Proposal
8601 "" 1 PyramidROIAlign_TRT
8601 "" 1 Region_TRT
8601 "" 1 Reorg_TRT
8601 "" 1 ResizeNearest_TRT
8601 "" 1 ROIAlign_TRT
8601 "" 1 RPROI_TRT
8601 "" 1 ScatterND
8601 "" 1 SpecialSlice_TRT
8601 "" 1 Split
8601 "" 1 VoxelGeneratorPlugin
8601 "" 1 AddScalar
0->scalar, PluginFieldType.FLOAT32, 1, None
scalar PluginFieldType.FLOAT32 1 <capsule object NULL at 0x7fea9d292de0>
plugin.plugin_type = AddScalar
plugin.plugin_namespace =
plugin.plugin_version = 1
plugin.clone() = <tensorrt.tensorrt.IPluginV2Ext object at 0x7fea95c2b4f0>
plugin.get_output_data_type(0, [trt.float32]) = DataType.FLOAT
plugin.get_output_shape(0, [trt.Dims([2])])) = (0)
plugin.get_workspace_size(1) = 0
<tensorrt.tensorrt.IPluginV2 object at 0x7fea95c2c4b0>
Finish
总结
手写Plugin难度较高,尤其是Plugin的核函数部分,建议参考Nvidia的教程模板进行修改,另外针对于Plugin的读取,需要大家了解并掌握详细的读取步骤,实现对Plugin的注册和配置,并最终实现将自己手写的Plugin插入到网络中。