TensorRT学习(实战-自定义算子)

YOLOv4进行TensorRT推理的时候会使用Mish激活函数,而使用到的mish激活函数没有在TensorRT进行实现。故需要进行实现对应TensorRT插件,故需要进行Mish激活函数的实现。

Mish激活函数定义

def mish_fun(x):
    tmp = np.log(1 + np.exp(x))
    tmp = np.tanh(tmp)
    tmp = tmp * x
    return tmp

如上所示激活函数的表达式,该激活还是比较复杂的,需要实现对应的tensorRT插件

tensorRT 插件

使用C++ API添加自定义图层
您可以通过从TensorRT的插件基类之一派生来实现自定义层。
从插件的一个基类派生插件类。它们在支持具有不同类型/格式的I/O或具有动态形状的网络方面具有不同的表达能力。下表总结了基类,按表达性从低到高的顺序排列。
注意:如果插件是用于一般用途,请提供FP32实现,以便允许它在任何网络上正常运行。

Table 3. Base Classes, Ordered from Least Expressive to Most Expressive
Introduced in TensorRT version?Mixed I/O formats/typesDynamic shapes?Supports implicit/explicit batch mode?
IPluginV2Ext5.1LimitedNoBoth implicit and explicit batch modes
IPluginV2IOExt6.0.1GeneralNoBoth implicit and explicit batch modes
IPluginV2DynamicExt6.0.1GeneralYesExplicit batch mode only

为了在网络中使用插件,您必须首先在TensorRT的PluginRegistry(C++,Python)中注册它。不是直接注册插件,而是注册插件的工厂类的实例,从PluginCreator(C++,Python)派生。plugin creator类还提供了有关插件的其他信息:其名称、版本和插件字段参数。

有两种方法可以在注册表中注册插件:

TensorRT提供了一个宏REGISTER_TENSORT_PLUGIN,用于在注册表中静态注册插件创建者。请注意,NREGISTER_TENSORT_PLUGI始终在默认名称空间(“”)下注册创建者。
通过创建您自己的入口点(类似于initLibNvInferPlugins)并在插件注册表上调用registerCreator来registerCreator。这比静态注册更好,因为它提供了潜在的更低的内存占用,并允许插件在唯一的名称空间下注册。这确保了在不同插件库之间的构建时间期间没有名称冲突。
用IPluginCreator::createPlugin()返回一个IPluginV2类型的插件对象。您可以使用addPluginV2()将插件添加到TensorRT网络,这将使用给定的插件创建网络层。

例如,您可以向网络添加插件层,如下所示:

// Look up the plugin in the registry
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// Populate the fields parameters for the plugin layer 
// PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields); 
// Create the plugin object using the layerName and the plugin meta data
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// Add the plugin to the TensorRT network 
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
… (build rest of the network and serialize engine)
// Destroy the plugin object
pluginObj->destroy()
… (free allocated pluginData)


注意:前面描述的createPlugin方法在堆上创建了一个新的插件对象,并返回一个指向它的指针。确保销毁pluginObj,如前所示,以避免内存泄漏。
在序列化期间,TensorRT引擎在内部存储所有IPluginV2类型插件的插件类型、插件版本和命名空间(如果存在)。在反序列化过程中,TensorRT从插件注册表中查找插件创建者,并调用IPluginCreator::deserializePlugin()。当引擎被删除时,引擎通过调用IPluginV2::destroy()方法销毁在引擎构建期间创建的插件对象的克隆。您有责任确保您创建的插件对象在添加到网络后被释

注:

不要序列化所有插件参数:只有那些插件在运行时正确运行所需的。可以省略生成时间参数。
以相同的顺序序列化和反序列化插件参数。在反序列化过程中,验证插件参数是否初始化为默认值或反序列化值。未初始化的参数会导致未定义的行为。
如果您是汽车安全用户,则必须调用getSafePluginRegistry()而不是getPluginRegistry()。还必须使用REGISTER_SAFE_TENSORT_PLUGIN宏,而不是REGISTER_TENSORT_PLUGIN。

关键类说明

IPluginV2Ext

用户实现层的插件创建类。

virtual nvinfer1::DataType getOutputDataType (int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
返回请求索引处插件输出的DataType, More...
virtual bool isOutputBroadcastAcrossBatch (int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
如果输出张量在批处理中广播,则返回true。. More...
virtual bool canBroadcastInputAcrossBatch (int32_t inputIndex) const noexcept=0
如果插件可以使用跨批广播的输入而无需复制,则返回true. More...
virtual void configurePlugin (Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
使用输入和输出数据类型配置图层. More...
IPluginV2Ext ()=default
~IPluginV2Ext () override=default
virtual void attachToContext (cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
将插件对象附加到执行上下文,并授予插件对某些上下文资源的访问权限 More...
virtual void detachFromContext () noexcept
从执行上下文中分离插件对象. More...
IPluginV2Ext * clone () const noexcept override=0

克隆插件对象。这也会复制内部插件参数,并返回一个带有这些参数的新插件对象。如果源插件预先配置了configurePlugin(),返回的对象也应该是预先配置好的。返回的对象应该允许连接到上下文() 克隆的插件对象可以与源对象共享相同的每引擎不可变资源(例如,权重)(例如,经由引用计数)以避免重复

 Public Member Functions inherited from nvinfer1::IPluginV2
virtual AsciiChar const * getPluginType () const noexcept=0
返回插件类型。应与相应插件创建者返回的插件名称匹配。 More...
virtual AsciiChar const * getPluginVersion () const noexcept=0
返回插件版本。应该与相应插件创建者返回的插件版本相匹配r. More...
virtual int32_t getNbOutputs () const noexcept=0
获取层的输出数量 More...
virtual Dims getOutputDimensions (int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0
获取输出张量的维度. More...
virtual bool supportsFormat (DataType type, PluginFormat format) const noexcept=0
检查格式支持 More...
virtual int32_t initialize () noexcept=0
I初始化要执行的层。这在引擎创建时调用More...
virtual void terminate () noexcept=0
释放插件层初始化过程中获取的资源。这叫engine被毁 More...
virtual size_t getWorkspaceSize (int32_t maxBatchSize) const noexcept=0
    查找图层所需的工作空间大小. More...
virtual int32_t enqueue (int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0
执行图层 More...
virtual size_t getSerializationSize () const noexcept=0
    查找所需的序列化缓冲区的大小 More...
virtual void serialize (void *buffer) const noexcept=0
    序列化图层r. More...
virtual void destroy () noexcept=0
    销毁插件对象。这将在网络、构建器或引擎被销毁时调用. More...
virtual void setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0
设置此插件对象所属的命名空间。理想情况下,同一插件库中的所有插件对象都应该具有相同的命名空间. More...
virtual AsciiChar const * getPluginNamespace () const noexcept=0
返回插件对象的命名空间. More...

Protected Member Functions

int32_t getTensorRTVersion () const noexcept override
返回构建此插件的API版本。TensorRT保留的高位字节,用于区分此插件和IPluginV2IPluginV2More...
void configureWithFormat (Dims const *, int32_t, Dims const *, int32_t, DataTypePluginFormat, int32_t) noexcept override
派生类不应该实现这个。在C++11 API中,它将被override final. More...

IPluginCreator

用户实现层的插件创建器类。

virtual int32_t getTensorRTVersion () const noexcept
返回插件创建者编译时使用的API版本. More...
virtual AsciiChar const * getPluginName () const noexcept=0
返回插件名称。. More...
virtual AsciiChar const * getPluginVersion () const noexcept=0
    返回插件版本. More...
virtual PluginFieldCollection const * getFieldNames () noexcept=0
返回需要传递给createPlugin的字段列表. More...
virtual IPluginV2 * createPlugin (AsciiChar const *name, PluginFieldCollection const *fc) noexcept=0
返回一个插件对象。错误时返回nullptr. More...
virtual IPluginV2 * deserializePlugin (AsciiChar const *name, void const *serialData, size_t serialLength) noexcept=0
    在插件层的反序列化过程中调用。返回插件对象
virtual void setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0
根据插件所属的插件库设置插件创建者的命名空间。这可以在注册插件创建者时设置。 More...
virtual AsciiChar const * getPluginNamespace () const noexcept=0
返回插件创建者对象的命名空间。 More...
IPluginCreator ()=default
virtual ~IPluginCreator ()=default

Mish激活函数

#ifndef _MISH_PLUGIN_H
#define _MISH_PLUGIN_H

#include <string>
#include <vector>
#include "NvInfer.h"

namespace nvinfer1
{
    class MishPlugin: public IPluginV2IOExt
    {
        public:
            // 显示的构造函数
            explicit MishPlugin();
            // 构造函数
            MishPlugin(const void* data, size_t length);
            ~MishPlugin();
            // 返回plugin的输出数量
            int getNbOutputs() const override
            {
                return 1;
            }
            // 输出张量的输出的维度
            Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
            // 初始化要执行的层。这在引擎创建时调用
            int initialize() override;
            // 释放插件层初始化过程中获取的资源
            virtual void terminate() override {};
            // 获得该层工作空间的大小
            virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
            // 执行该的层的处理
            virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
            //  查找所需的序列化缓冲区的大小
            virtual size_t getSerializationSize() const override;
            // 进行对应的序列化
            virtual void serialize(void* buffer) const override;

            bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
                return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
            }

            const char* getPluginType() const override;

            const char* getPluginVersion() const override;
            //   销毁插件对象。这将在网络、构建器或引擎被销毁时调用
           void destroy() override;
            // 克隆对象
            IPluginV2IOExt* clone() const override;
            // 这是对象所属的命名空间
            void setPluginNamespace(const char* pluginNamespace) override;
            // 返回对象的命名空间
            const char* getPluginNamespace() const override;
            // 返回对象输出时间
            DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
            // 是否进行舒畅张量在批处理中广播
            bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
            // 如果插件可以使用跨批广播的输入而无需刻意的复制
            bool canBroadcastInputAcrossBatch(int inputIndex) const override;
            // 将插件对象附加到执行上下文,
            void attachToContext(
                    cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
            // 使用输入和输出数据类型配置网络层
            void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
            // 从执行上下文中分离插件对象.
            void detachFromContext() override;

            int input_size_;
        private:
            void forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize = 1);
            int thread_count_ = 256;
            const char* mPluginNamespace;
    };

    class MishPluginCreator : public IPluginCreator
    {
        public:
            // 构造函数
            MishPluginCreator();
            // 析构函数
            ~MishPluginCreator() override = default;
            // 获得插件的名字
            const char* getPluginName() const override;
            // 获得插件的版本
            const char* getPluginVersion() const override;
            // 返回需要传递给createPlugin的字段列表
            const PluginFieldCollection* getFieldNames() override;
            // 返回一个对应的插件对象
            IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
            // 在插件层的反序列化过程中调用。返回插件对象
            IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
            // 根据插件所属的插件库设置插件创建者的命名空间
            void setPluginNamespace(const char* libNamespace) override
            {
                mNamespace = libNamespace;
            }
            // 返回插件创建者对象的命名空间
            const char* getPluginNamespace() const override
            {
                return mNamespace.c_str();
            }

        private:
            std::string mNamespace;
            static PluginFieldCollection mFC;
            static std::vector<PluginField> mPluginAttributes;
    };
    REGISTER_TENSORRT_PLUGIN(MishPluginCreator);
};
#endif 

#include <cmath>
#include <stdio.h>
#include <cassert>
#include <iostream>
#include "mish.h"

namespace nvinfer1
{
    MishPlugin::MishPlugin()
    {
    }

    MishPlugin::~MishPlugin()
    {
    }

    // create the plugin at runtime from a byte stream
    MishPlugin::MishPlugin(const void* data, size_t length)
    {
        assert(length == sizeof(input_size_));
        input_size_ = *reinterpret_cast<const int*>(data);
    }

    void MishPlugin::serialize(void* buffer) const
    {
        *reinterpret_cast<int*>(buffer) = input_size_;
    }

    size_t MishPlugin::getSerializationSize() const
    {  
        return sizeof(input_size_);
    }

    int MishPlugin::initialize()
    { 
        return 0;
    }

    Dims MishPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
    {
        assert(nbInputDims == 1);
        assert(index == 0);
        input_size_ = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
        // Output dimensions
        return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
    }

    // Set plugin namespace
    void MishPlugin::setPluginNamespace(const char* pluginNamespace)
    {
        mPluginNamespace = pluginNamespace;
    }

    const char* MishPlugin::getPluginNamespace() const
    {
        return mPluginNamespace;
    }

    // Return the DataType of the plugin output at the requested index
    DataType MishPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
    {
        return DataType::kFLOAT;
    }

    // Return true if output tensor is broadcast across a batch.
    bool MishPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
    {
        return false;
    }

    // Return true if plugin can use input that is broadcast across batch without replication.
    bool MishPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
    {
        return false;
    }

    void MishPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
    {
    }

    // Attach the plugin object to an execution context and grant the plugin the access to some context resource.
    void MishPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
    {
    }

    // Detach the plugin object from its execution context.
    void MishPlugin::detachFromContext() {}

    const char* MishPlugin::getPluginType() const
    {
        return "Mish_TRT";
    }

    const char* MishPlugin::getPluginVersion() const
    {
        return "1";
    }

    void MishPlugin::destroy()
    {
        delete this;
    }

    // Clone the plugin
    IPluginV2IOExt* MishPlugin::clone() const
    {
        MishPlugin *p = new MishPlugin();
        p->input_size_ = input_size_;
        p->setPluginNamespace(mPluginNamespace);
        return p;
    }

    __device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}

    __device__ float softplus_kernel(float x, float threshold = 20) {
        if (x > threshold) return x;                // too large
        else if (x < -threshold) return expf(x);    // too small
        return logf(expf(x) + 1);
    }

    __global__ void mish_kernel(const float *input, float *output, int num_elem) {

        int idx = threadIdx.x + blockDim.x * blockIdx.x;
        if (idx >= num_elem) return;

        //float t = exp(input[idx]);
        //if (input[idx] > 20.0) {
        //    t *= t;
        //    output[idx] = (t - 1.0) / (t + 1.0);
        //} else {
        //    float tt = t * t;
        //    output[idx] = (tt + 2.0 * t) / (tt + 2.0 * t + 2.0);
        //}
        //output[idx] *= input[idx];
        output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));
    }

    void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
        int block_size = thread_count_;
        int grid_size = (input_size_ * batchSize + block_size - 1) / block_size;
        mish_kernel<<<grid_size, block_size>>>(inputs[0], output, input_size_ * batchSize);
    }

    int MishPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
    {
        //assert(batchSize == 1);
        //GPU
        //CUDA_CHECK(cudaStreamSynchronize(stream));
        forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
        return 0;
    }

    PluginFieldCollection MishPluginCreator::mFC{};
    std::vector<PluginField> MishPluginCreator::mPluginAttributes;

    MishPluginCreator::MishPluginCreator()
    {
        mPluginAttributes.clear();

        mFC.nbFields = mPluginAttributes.size();
        mFC.fields = mPluginAttributes.data();
    }

    const char* MishPluginCreator::getPluginName() const
    {
            return "Mish_TRT";
    }

    const char* MishPluginCreator::getPluginVersion() const
    {
            return "1";
    }

    const PluginFieldCollection* MishPluginCreator::getFieldNames()
    {
            return &mFC;
    }

    IPluginV2IOExt* MishPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
    {
        MishPlugin* obj = new MishPlugin();
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }

    IPluginV2IOExt* MishPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
    {
        // This object will be deleted when the network is destroyed, which will
        // call MishPlugin::destroy()
        MishPlugin* obj = new MishPlugin(serialData, serialLength);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }

}

将该对象添加到模型中

Weights emptywts{DataType::kFLOAT, nullptr, 0};
    //卷积层处理
        //!
    //! \brief Add a multi-dimension convolution layer to the network.
    //!
    //! \param  The ininputput tensor to the convolution.
    //! \param nbOutputMaps The number of output feature maps for the convolution.
    //! \param kernelSize The multi-dimensions of the convolution kernel.
    //! \param kernelWeights The kernel weights for the convolution.
    //! \param biasWeights The optional bias weights for the convolution.
    // IConvolutionLayer* addConvolutionNd(
    //     ITensor& input, int32_t nbOutputMaps, Dims kernelSize, Weights kernelWeights, Weights biasWeights) noexcept
    // {
    //     return mImpl->addConvolutionNd(input, nbOutputMaps, kernelSize, kernelWeights, biasWeights);
    // }
    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap["module_list." + std::to_string(linx) + ".Conv2d.weight"], emptywts);
    assert(conv1);
    // 设置对应参数
    conv1->setStrideNd(DimsHW{s, s});
    conv1->setPaddingNd(DimsHW{p, p});

    // 设置对应的批量归一化数据
    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "module_list." + std::to_string(linx) + ".BatchNorm2d", 1e-4);

    // 创建对应的mish激活函数
    auto creator = getPluginRegistry()->getPluginCreator("Mish_TRT", "1");
    const PluginFieldCollection* pluginData = creator->getFieldNames();
    IPluginV2 *pluginObj = creator->createPlugin(("mish" + std::to_string(linx)).c_str(), pluginData);
    ITensor* inputTensors[] = {bn1->getOutput(0)};
    auto mish = network->addPluginV2(&inputTensors[0], 1, *pluginObj);
    return mish;
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值