TensorRT Plugin: A Detailed Step-by-Step Tutorial on Writing a Plugin by Hand


Preface

This article is a set of study notes from NVIDIA's official video tutorial series and walks through the detailed steps of writing a TensorRT plugin by hand. The code follows the 05-Plugin/API sample in NVIDIA's cookbook (trt-samples-for-hackathon-cn). It is strongly recommended that you watch Part 3 of NVIDIA's tutorial videos before working through this document.


I. plugin_creator_list

Let's start with the plugin_creator_list. This list holds the official plugins shipped with TensorRT, and any of them can be called directly through add_plugin_v2. When we run into an operator that TensorRT does not support, we write a plugin by hand, register it into the plugin_creator_list, and then call it in exactly the same way.

The structure of the plugin_creator_list is shown below. The list contains the plugins provided by TensorRT, each wrapped in a creator. Every creator records the TensorRT version, the namespace, the plugin version, and the plugin name; in addition, its PluginFields describe the data types and parameters the plugin requires, so a plugin is configured by filling in these fields.
(Figure: structure of the plugin_creator_list)
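As a quick illustration, here is a minimal sketch that lists the registered creators and the parameters (fields) each one expects. It only uses the Python API calls that also appear in the main.py sample later in this article:

import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, "")  # load the built-in plugin creators into the registry

for creator in trt.get_plugin_registry().plugin_creator_list:
    # every plugin is wrapped in a creator that records the TensorRT version, namespace, plugin version and name
    print(creator.tensorrt_version, creator.plugin_namespace, creator.plugin_version, creator.name)
    for field in creator.field_names:
        # each PluginField describes one parameter the plugin expects (name, type, size)
        print("    ", field.name, field.type, field.size)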

II. Detailed Steps for Writing a Plugin by Hand

1. Writing the Plugin Kernel and Plugin Classes

The code is as follows (corresponding to trt-samples-for-hackathon-cn\cookbook\05-Plugin\API\AddScalarPlugin.cu in the cookbook):


#include "AddScalarPlugin.h"

// kernel for GPU
__global__ void addScalarKernel(const float *input, float *output, const float scalar, const int nElement)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= nElement)
        return;

    float _1      = input[index];
    float _2      = _1 + scalar;
    output[index] = _2;
}

namespace nvinfer1
{
// class AddScalarPlugin
AddScalarPlugin::AddScalarPlugin(const std::string &name, float scalar):
    name_(name)
{
    WHERE_AM_I();
    m_.scalar = scalar;
}

AddScalarPlugin::AddScalarPlugin(const std::string &name, const void *buffer, size_t length):
    name_(name)
{
    WHERE_AM_I();
    memcpy(&m_, buffer, sizeof(m_));
}

AddScalarPlugin::~AddScalarPlugin()
{
    WHERE_AM_I();
}

IPluginV2DynamicExt *AddScalarPlugin::clone() const noexcept
{
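    // Deep-copy the plugin object; TensorRT calls clone() whenever it needs an extra copy of the plugin, e.g. during engine building.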
    WHERE_AM_I();
    auto p = new AddScalarPlugin(name_, &m_, sizeof(m_));
    p->setPluginNamespace(namespace_.c_str());
    return p;
}

int32_t AddScalarPlugin::getNbOutputs() const noexcept
{
    WHERE_AM_I();
    return 1;
}

DataType AddScalarPlugin::getOutputDataType(int32_t index, DataType const *inputTypes, int32_t nbInputs) const noexcept
{
    WHERE_AM_I();
    return inputTypes[0];
}

DimsExprs AddScalarPlugin::getOutputDimensions(int32_t outputIndex, const DimsExprs *inputs, int32_t nbInputs, IExprBuilder &exprBuilder) noexcept
{
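    // The output shape equals the input shape: adding a scalar does not change the tensor dimensions.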
    WHERE_AM_I();
    return inputs[0];
}

bool AddScalarPlugin::supportsFormatCombination(int32_t pos, const PluginTensorDesc *inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
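    // Report which type/format combinations are supported: the input must be FP32 in linear layout, and the output must match the input.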
    WHERE_AM_I();
    bool res;
    switch (pos)
    {
    case 0:
        res = inOut[0].type == DataType::kFLOAT && inOut[0].format == TensorFormat::kLINEAR;
        break;
    case 1:
        res = inOut[1].type == inOut[0].type && inOut[1].format == inOut[0].format;
        break;
    default: // should NOT be here!
        res = false;
    }
#ifdef DEBUG
    std::cout << "\tpos=" << pos << ",res=" << res << "->[";
    for (int i = 0; i < nbInputs + nbOutputs; ++i)
    {
        std::cout << formatToString(inOut[i].format) << ",";
    }
    std::cout << "],[";
    for (int i = 0; i < nbInputs + nbOutputs; ++i)
    {
        std::cout << dataTypeToString(inOut[i].type) << ",";
    }
    std::cout << "]" << std::endl;
#endif
    return res;
}

void AddScalarPlugin::configurePlugin(const DynamicPluginTensorDesc *in, int32_t nbInputs, const DynamicPluginTensorDesc *out, int32_t nbOutputs) noexcept
{
    WHERE_AM_I();
    return;
}

size_t AddScalarPlugin::getWorkspaceSize(const PluginTensorDesc *inputs, int32_t nbInputs, const PluginTensorDesc *outputs, int32_t nbOutputs) const noexcept
{
    WHERE_AM_I();
    return 0;
}

int32_t AddScalarPlugin::enqueue(const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept
{
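    // The actual computation: count the elements of the input tensor, then launch the CUDA kernel on the stream provided by TensorRT.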
    WHERE_AM_I();
    int nElement = 1;
    for (int i = 0; i < inputDesc[0].dims.nbDims; ++i)
    {
        nElement *= inputDesc[0].dims.d[i];
    }
    dim3 grid(CEIL_DIVIDE(nElement, 256), 1, 1), block(256, 1, 1);
    addScalarKernel<<<grid, block, 0, stream>>>(reinterpret_cast<const float *>(inputs[0]), reinterpret_cast<float *>(outputs[0]), m_.scalar, nElement);
    return 0;
}

void AddScalarPlugin::destroy() noexcept
{
    WHERE_AM_I();
    delete this;
    return;
}

int32_t AddScalarPlugin::initialize() noexcept
{
    WHERE_AM_I();
    return 0;
}

void AddScalarPlugin::terminate() noexcept
{
    WHERE_AM_I();
    return;
}

size_t AddScalarPlugin::getSerializationSize() const noexcept
{
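    // Only the member struct m_ (which holds the scalar) needs to be serialized into the engine.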
    WHERE_AM_I();
    return sizeof(m_);
}

void AddScalarPlugin::serialize(void *buffer) const noexcept
{
    WHERE_AM_I();
    memcpy(buffer, &m_, sizeof(m_));
    return;
}

void AddScalarPlugin::setPluginNamespace(const char *pluginNamespace) noexcept
{
    WHERE_AM_I();
    namespace_ = pluginNamespace;
    return;
}

const char *AddScalarPlugin::getPluginNamespace() const noexcept
{
    WHERE_AM_I();
    return namespace_.c_str();
}

const char *AddScalarPlugin::getPluginType() const noexcept
{
    WHERE_AM_I();
    return PLUGIN_NAME;
}

const char *AddScalarPlugin::getPluginVersion() const noexcept
{
    WHERE_AM_I();
    return PLUGIN_VERSION;
}

void AddScalarPlugin::attachToContext(cudnnContext *contextCudnn, cublasContext *contextCublas, IGpuAllocator *gpuAllocator) noexcept
{
    WHERE_AM_I();
    return;
}

void AddScalarPlugin::detachFromContext() noexcept
{
    WHERE_AM_I();
    return;
}

// class AddScalarPluginCreator
PluginFieldCollection    AddScalarPluginCreator::fc_ {};
std::vector<PluginField> AddScalarPluginCreator::attr_;

AddScalarPluginCreator::AddScalarPluginCreator()
{
    WHERE_AM_I();
    attr_.clear();
    attr_.emplace_back(PluginField("scalar", nullptr, PluginFieldType::kFLOAT32, 1));
    fc_.nbFields = attr_.size();
    fc_.fields   = attr_.data();
}

AddScalarPluginCreator::~AddScalarPluginCreator()
{
    WHERE_AM_I();
}

IPluginV2DynamicExt *AddScalarPluginCreator::createPlugin(const char *name, const PluginFieldCollection *fc) noexcept
{
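    // Parse the "scalar" value from the PluginFieldCollection passed in by the caller and construct the plugin with it.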
    WHERE_AM_I();
    float                          scalar = 0;
    std::map<std::string, float *> parameterMap {{"scalar", &scalar}};

    for (int i = 0; i < fc->nbFields; ++i)
    {
        if (parameterMap.find(fc->fields[i].name) != parameterMap.end())
        {
            *parameterMap[fc->fields[i].name] = *reinterpret_cast<const float *>(fc->fields[i].data);
        }
    }
    AddScalarPlugin *pObj = new AddScalarPlugin(name, scalar);
    pObj->setPluginNamespace(namespace_.c_str());
    return pObj;
}

IPluginV2DynamicExt *AddScalarPluginCreator::deserializePlugin(const char *name, const void *serialData, size_t serialLength) noexcept
{
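    // Rebuild the plugin from its serialized buffer, e.g. when an engine containing this plugin is deserialized.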
    WHERE_AM_I();
    AddScalarPlugin *pObj = new AddScalarPlugin(name, serialData, serialLength);
    pObj->setPluginNamespace(namespace_.c_str());
    return pObj;
}

void AddScalarPluginCreator::setPluginNamespace(const char *pluginNamespace) noexcept
{
    WHERE_AM_I();
    namespace_ = pluginNamespace;
    return;
}

const char *AddScalarPluginCreator::getPluginNamespace() const noexcept
{
    WHERE_AM_I();
    return namespace_.c_str();
}

const char *AddScalarPluginCreator::getPluginName() const noexcept
{
    WHERE_AM_I();
    return PLUGIN_NAME;
}

const char *AddScalarPluginCreator::getPluginVersion() const noexcept
{
    WHERE_AM_I();
    return PLUGIN_VERSION;
}

const PluginFieldCollection *AddScalarPluginCreator::getFieldNames() noexcept
{
    WHERE_AM_I();
    return &fc_;
}

REGISTER_TENSORRT_PLUGIN(AddScalarPluginCreator);
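// REGISTER_TENSORRT_PLUGIN statically registers AddScalarPluginCreator, so "AddScalar" appears in the plugin_creator_list once this .so is loaded.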

} // namespace nvinfer1

Due to space constraints, the corresponding .h header file is not shown here; please refer to the cookbook. From the plugin source above we need to build the corresponding Plugin .so shared library, so that the plugin can be loaded and used in the next step.

2. Loading and Using the Plugin

In the previous step we built the Plugin .so library. In this step we load that .so file, register and configure the plugin, and finally insert our hand-written plugin into the network.

The overall flow of initializing, registering, and configuring the plugin is as follows:
(Figure: flow of plugin initialization, registration, and configuration)
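Here is a minimal sketch of this flow, condensed from the full main.py shown further below (the scalar value 1.0 is only an example):

import ctypes
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, "")          # 1. initialize the built-in plugins
ctypes.cdll.LoadLibrary("./AddScalarPlugin.so")  # 2. load our own .so, which registers AddScalarPluginCreator

for creator in trt.get_plugin_registry().plugin_creator_list:
    if creator.name == "AddScalar" and creator.plugin_version == "1":  # 3. find our creator by name and version
        fc = trt.PluginFieldCollection([trt.PluginField("scalar", np.float32(1.0), trt.PluginFieldType.FLOAT32)])
        plugin = creator.create_plugin(creator.name, fc)               # 4. create the plugin from its parameters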
After the plugin is created, we can also serialize and deserialize it:

pluginString = plugin.serialize()
plugin = creator.deserialize_plugin(creator.name, pluginString)  # create a plugin by memory of serialized plugin

Finally, insert the plugin into the network:

pluginLayer = network.add_plugin_v2([inputT0], plugin)

The full code is as follows (corresponding to trt-samples-for-hackathon-cn\cookbook\05-Plugin\API\main.py in the cookbook):

import ctypes
import os
from glob import glob

import numpy as np
import tensorrt as trt
from cuda import cudart

soFile = "./AddScalarPlugin.so"
np.set_printoptions(precision=3, linewidth=200, suppress=True)
np.random.seed(31193)
cudart.cudaDeviceSynchronize()

def getAddScalarPlugin(scalar):
    for c in trt.get_plugin_registry().plugin_creator_list:
        #print(c.name)
        if c.name == "AddScalar":
            parameterList = []
            parameterList.append(trt.PluginField("scalar", np.float32(scalar), trt.PluginFieldType.FLOAT32))
            return c.create_plugin(c.name, trt.PluginFieldCollection(parameterList))
    return None

# os.chdir("/w/gitlab/tensorrt-cookbook/05-Plugin/API/")

# Load default plugin creators
logger = trt.Logger(trt.Logger.ERROR)  # create the logger
trt.init_libnvinfer_plugins(logger, '')  # initialize the built-in plugins

pluginRegistry = trt.get_plugin_registry()  # get the plugin registry
print("Count of default plugin creators = %d" % len(pluginRegistry.plugin_creator_list))  # print how many plugin creators the registry already contains

# Attributes of the plugin registry
print("pluginRegistry.error_recorder =", pluginRegistry.error_recorder)  # an ErrorRecorder can be attached to the registry; for usage of ErrorRecorder refer to 02-API/ErrorRecorder
pluginRegistry.parent_search_enabled = True  # whether to also search plugin creators in the parent registry, default value is True

# Load local plugin creators
for soFile in glob("./*.so"):#读取自己创建的plugin
    if True:  # common method
        ctypes.cdll.LoadLibrary(soFile)
    else:  # use TensorRT API, but there are some problems, do not use this temporarily
        handle = pluginRegistry.load_library(soFile)
        #pluginRegistry.deregister_library(handle)  # deregiste the library
print("Count of total plugin creators = %d" % len(pluginRegistry.plugin_creator_list))  #长度+1 / one more plugin creator "AddScalar" added

#pluginRegistry.deregister_library(?)  # deregister the library

# print information of all plugin creators
print("TensorRTVersion Namespace PluginVersion Name")
for creator in pluginRegistry.plugin_creator_list:
    print("%4s            %s        %s             %s" % (creator.tensorrt_version, ("\"\"" if creator.plugin_namespace == "" else creator.plugin_namespace), creator.plugin_version, creator.name))

for creator in pluginRegistry.plugin_creator_list:
    if creator.name == "AddScalar" and creator.plugin_version == "1":  # check name and version during selecting plugin

        # print the necessary parameters for creating the plugin
        for i, pluginField in enumerate(creator.field_names):
            print("%2d->%s, %s, %s, %s" % (i, pluginField.name, pluginField.type, pluginField.size, pluginField.data))
            #i=0; pluginField.name=scalar; pluginField.type=PluginFieldType.FLOAT32; pluginField.size= 1;pluginField.data= None;
        # We can register and deregister a plugin creator in the plugin registry, but this is not required
        #pluginRegistry.deregister_creator(creator)  # deregister the plugin creator
        #pluginRegistry.register_creator(creator)  # register the plugin creator again

        # feed the PluginCreator with parameters
        pluginFieldCollection = trt.PluginFieldCollection()  # configure the PluginFieldCollection
        pluginField = trt.PluginField("scalar", np.float32(1.0), trt.PluginFieldType.FLOAT32)
        # tensorrt.PluginFieldType: FLOAT16, FLOAT32, FLOAT64, INT8, INT16, INT32, CHAR, DIMS, UNKNOWN
        print(pluginField.name, pluginField.type, pluginField.size, pluginField.data)#scalar PluginFieldType.FLOAT32 1 <capsule object NULL at 0x7fea9d292de0>

        pluginFieldCollection.append(pluginField)  # use like a list
        #pluginFieldCollection.insert(1,pluginField)
        #pluginFieldCollection.extend([pluginField])
        #pluginFieldCollection.clear()
        #pluginFieldCollection.pop(1)
        plugin = creator.create_plugin(creator.name, pluginFieldCollection)  # create a plugin by parameters

        plugin.__class__ = trt.IPluginV2Ext  # change the class of the plugin from IPluginV2 to IPluginV2Ext; there is no IPluginV2Dynamic class in the Python API yet

        # methods that do not work in the Python API
        # plugin.supports_format(trt.float32, None)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.attach_to_context(None, None)
        #plugin.detach_from_context()
        #plugin.configure_with_format([[2]], [[2]], trt.float32, None, 1)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.configure_plugin([[2]],[[2]],[trt.float32],[trt.float32],[False],[False], None, 1)  # nvinfer1::TensorFormat::kLINEAR
        #plugin.execute_async(1, [None], [None], None, 0)  # address of input / output / workspace memory
        #plugin.initialize()
        #plugin.terminate()
        #plugin.destroy()

        # methods that work (but are not very useful) in the Python API
        print("plugin.plugin_type =", plugin.plugin_type)
        print("plugin.plugin_namespace =", plugin.plugin_namespace)
        print("plugin.plugin_version =", plugin.plugin_version)
        print("plugin.num_outputs =", plugin.num_outputs)
        print("plugin.serialization_size =", plugin.serialization_size)
        print("plugin.tensorrt_version =", plugin.tensorrt_version)
        print("plugin.clone() =", plugin.clone())
        print("plugin.get_output_data_type(0, [trt.float32]) =", plugin.get_output_data_type(0, [trt.float32]))
        print("plugin.get_output_shape(0, [trt.Dims([2])])) =", plugin.get_output_shape(0, [trt.Dims([2])]))  # output is always ((0))?
        print("plugin.get_workspace_size(1) =", plugin.get_workspace_size(1))  # output is always 0?

        pluginString = plugin.serialize()
        plugin = creator.deserialize_plugin(creator.name, pluginString)  # create a plugin by memory of serialized plugin

builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
config = builder.create_builder_config()

inputT0 = network.add_input("inputT0", trt.float32, [-1])
profile.set_shape(inputT0.name, [1], [2], [4])
config.add_optimization_profile(profile)

pluginLayer = network.add_plugin_v2([inputT0], plugin)
print(pluginLayer.plugin)  # other members and methods refer to 02-API/Layer

print("Finish")

# methods that do not work
#trt.get_builder_plugin_registry(None)  # nvinfer1::EngineCapability

Output:

python3 ./main.py
Count of default plugin creators = 67
pluginRegistry.error_recorder = None
Count of total plugin creators = 68
TensorRTVersion Namespace PluginVersion Name
8601            ""        1             CaskDeconvShaderWeightsTransformerPlugin
8601            ""        1             CaskConvShaderWeightsTransformerPlugin
8601            ""        1             RNNTEncoderPlugin
8601            ""        1             SmallTileGEMM_TRT
8601            ""        1             DLRM_BOTTOM_MLP_TRT
8601            ""        1             CustomQKVToContextPluginDynamic
8601            ""        2             CustomQKVToContextPluginDynamic
8601            ""        3             CustomQKVToContextPluginDynamic
8601            ""        1             CustomSkipLayerNormPluginDynamic
8601            ""        2             CustomSkipLayerNormPluginDynamic
8601            ""        3             CustomSkipLayerNormPluginDynamic
8601            ""        4             CustomSkipLayerNormPluginDynamic
8601            ""        1             SingleStepLSTMPlugin
8601            ""        1             RnRes2FullFusion_TRT
8601            ""        1             RnRes2Br2bBr2c_TRT
8601            ""        2             RnRes2Br2bBr2c_TRT
8601            ""        1             RnRes2Br1Br2c_TRT
8601            ""        2             RnRes2Br1Br2c_TRT
8601            ""        1             GroupNormalizationPlugin
8601            ""        1             CustomGeluPluginDynamic
8601            ""        1             CustomFCPluginDynamic
8601            ""        2             CustomEmbLayerNormPluginDynamic
8601            ""        3             CustomEmbLayerNormPluginDynamic
8601            ""        1             CustomEmbLayerNormPluginDynamic
8601            ""        1             DisentangledAttention_TRT
8601            ""        1             BatchedNMSDynamic_TRT
8601            ""        1             BatchedNMS_TRT
8601            ""        1             BatchTilePlugin_TRT
8601            ""        1             Clip_TRT
8601            ""        1             CoordConvAC
8601            ""        1             CropAndResizeDynamic
8601            ""        1             CropAndResize
8601            ""        1             DecodeBbox3DPlugin
8601            ""        1             DetectionLayer_TRT
8601            ""        1             EfficientNMS_Explicit_TF_TRT
8601            ""        1             EfficientNMS_Implicit_TF_TRT
8601            ""        1             EfficientNMS_ONNX_TRT
8601            ""        1             EfficientNMS_TRT
8601            ""        1             FlattenConcat_TRT
8601            ""        1             GenerateDetection_TRT
8601            ""        1             GridAnchor_TRT
8601            ""        1             GridAnchorRect_TRT
8601            ""        1             InstanceNormalization_TRT
8601            ""        2             InstanceNormalization_TRT
8601            ""        1             LReLU_TRT
8601            ""        1             ModulatedDeformConv2d
8601            ""        1             MultilevelCropAndResize_TRT
8601            ""        1             MultilevelProposeROI_TRT
8601            ""        1             MultiscaleDeformableAttnPlugin_TRT
8601            ""        1             NMSDynamic_TRT
8601            ""        1             NMS_TRT
8601            ""        1             Normalize_TRT
8601            ""        1             PillarScatterPlugin
8601            ""        1             PriorBox_TRT
8601            ""        1             ProposalDynamic
8601            ""        1             ProposalLayer_TRT
8601            ""        1             Proposal
8601            ""        1             PyramidROIAlign_TRT
8601            ""        1             Region_TRT
8601            ""        1             Reorg_TRT
8601            ""        1             ResizeNearest_TRT
8601            ""        1             ROIAlign_TRT
8601            ""        1             RPROI_TRT
8601            ""        1             ScatterND
8601            ""        1             SpecialSlice_TRT
8601            ""        1             Split
8601            ""        1             VoxelGeneratorPlugin
8601            ""        1             AddScalar
 0->scalar, PluginFieldType.FLOAT32, 1, None
scalar PluginFieldType.FLOAT32 1 <capsule object NULL at 0x7fea9d292de0>
plugin.plugin_type = AddScalar
plugin.plugin_namespace =
plugin.plugin_version = 1
plugin.clone() = <tensorrt.tensorrt.IPluginV2Ext object at 0x7fea95c2b4f0>
plugin.get_output_data_type(0, [trt.float32]) = DataType.FLOAT
plugin.get_output_shape(0, [trt.Dims([2])])) = (0)
plugin.get_workspace_size(1) = 0
<tensorrt.tensorrt.IPluginV2 object at 0x7fea95c2c4b0>
Finish

Summary

Writing a plugin by hand is fairly difficult, especially the kernel part, so it is recommended to start from NVIDIA's tutorial template and modify it. For loading the plugin, you need to understand and master the detailed steps of registering and configuring it, so that in the end you can insert your own hand-written plugin into the network.
