YOLOv4进行TensorRT推理的时候会使用Mish激活函数,而使用到的mish激活函数没有在TensorRT进行实现。故需要进行实现对应TensorRT插件,故需要进行Mish激活函数的实现。
Mish激活函数定义
def mish_fun(x):
tmp = np.log(1 + np.exp(x))
tmp = np.tanh(tmp)
tmp = tmp * x
return tmp
如上所示激活函数的表达式,该激活还是比较复杂的,需要实现对应的tensorRT插件
tensorRT 插件
使用C++ API添加自定义图层
您可以通过从TensorRT的插件基类之一派生来实现自定义层。
从插件的一个基类派生插件类。它们在支持具有不同类型/格式的I/O或具有动态形状的网络方面具有不同的表达能力。下表总结了基类,按表达性从低到高的顺序排列。
注意:如果插件是用于一般用途,请提供FP32实现,以便允许它在任何网络上正常运行。
Introduced in TensorRT version? | Mixed I/O formats/types | Dynamic shapes? | Supports implicit/explicit batch mode? | |
---|---|---|---|---|
IPluginV2Ext | 5.1 | Limited | No | Both implicit and explicit batch modes |
IPluginV2IOExt | 6.0.1 | General | No | Both implicit and explicit batch modes |
IPluginV2DynamicExt | 6.0.1 | General | Yes | Explicit batch mode only |
为了在网络中使用插件,您必须首先在TensorRT的PluginRegistry(C++,Python)中注册它。不是直接注册插件,而是注册插件的工厂类的实例,从PluginCreator(C++,Python)派生。plugin creator类还提供了有关插件的其他信息:其名称、版本和插件字段参数。
有两种方法可以在注册表中注册插件:
TensorRT提供了一个宏REGISTER_TENSORT_PLUGIN,用于在注册表中静态注册插件创建者。请注意,NREGISTER_TENSORT_PLUGI始终在默认名称空间(“”)下注册创建者。
通过创建您自己的入口点(类似于initLibNvInferPlugins)并在插件注册表上调用registerCreator来registerCreator。这比静态注册更好,因为它提供了潜在的更低的内存占用,并允许插件在唯一的名称空间下注册。这确保了在不同插件库之间的构建时间期间没有名称冲突。
调用IPluginCreator::createPlugin()返回一个IPluginV2类型的插件对象。您可以使用addPluginV2()将插件添加到TensorRT网络,这将使用给定的插件创建网络层。
例如,您可以向网络添加插件层,如下所示:
// Look up the plugin in the registry
auto creator = getPluginRegistry()->getPluginCreator(pluginName, pluginVersion);
const PluginFieldCollection* pluginFC = creator->getFieldNames();
// Populate the fields parameters for the plugin layer
// PluginFieldCollection *pluginData = parseAndFillFields(pluginFC, layerFields);
// Create the plugin object using the layerName and the plugin meta data
IPluginV2 *pluginObj = creator->createPlugin(layerName, pluginData);
// Add the plugin to the TensorRT network
auto layer = network.addPluginV2(&inputs[0], int(inputs.size()), pluginObj);
… (build rest of the network and serialize engine)
// Destroy the plugin object
pluginObj->destroy()
… (free allocated pluginData)
注意:前面描述的createPlugin方法在堆上创建了一个新的插件对象,并返回一个指向它的指针。确保销毁pluginObj,如前所示,以避免内存泄漏。
在序列化期间,TensorRT引擎在内部存储所有IPluginV2类型插件的插件类型、插件版本和命名空间(如果存在)。在反序列化过程中,TensorRT从插件注册表中查找插件创建者,并调用IPluginCreator::deserializePlugin()。当引擎被删除时,引擎通过调用IPluginV2::destroy()方法销毁在引擎构建期间创建的插件对象的克隆。您有责任确保您创建的插件对象在添加到网络后被释
注:
不要序列化所有插件参数:只有那些插件在运行时正确运行所需的。可以省略生成时间参数。
以相同的顺序序列化和反序列化插件参数。在反序列化过程中,验证插件参数是否初始化为默认值或反序列化值。未初始化的参数会导致未定义的行为。
如果您是汽车安全用户,则必须调用getSafePluginRegistry()而不是getPluginRegistry()。还必须使用REGISTER_SAFE_TENSORT_PLUGIN宏,而不是REGISTER_TENSORT_PLUGIN。
关键类说明
IPluginV2Ext
用户实现层的插件创建类。
virtual nvinfer1::DataType | getOutputDataType (int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0 |
返回请求索引处插件输出的DataType, More... | |
virtual bool | isOutputBroadcastAcrossBatch (int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0 |
如果输出张量在批处理中广播,则返回true。. More... | |
virtual bool | canBroadcastInputAcrossBatch (int32_t inputIndex) const noexcept=0 |
如果插件可以使用跨批广播的输入而无需复制,则返回true. More... | |
virtual void | configurePlugin (Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0 |
使用输入和输出数据类型配置图层. More... | |
IPluginV2Ext ()=default | |
~IPluginV2Ext () override=default | |
virtual void | attachToContext (cudnnContext *, cublasContext *, IGpuAllocator *) noexcept |
将插件对象附加到执行上下文,并授予插件对某些上下文资源的访问权限 More... | |
virtual void | detachFromContext () noexcept |
从执行上下文中分离插件对象. More... | |
IPluginV2Ext * | clone () const noexcept override=0 |
克隆插件对象。这也会复制内部插件参数,并返回一个带有这些参数的新插件对象。如果源插件预先配置了configurePlugin(),返回的对象也应该是预先配置好的。返回的对象应该允许连接到上下文() 克隆的插件对象可以与源对象共享相同的每引擎不可变资源(例如,权重)(例如,经由引用计数)以避免重复 | |
Public Member Functions inherited from nvinfer1::IPluginV2 | |
virtual AsciiChar const * | getPluginType () const noexcept=0 |
返回插件类型。应与相应插件创建者返回的插件名称匹配。 More... | |
virtual AsciiChar const * | getPluginVersion () const noexcept=0 |
返回插件版本。应该与相应插件创建者返回的插件版本相匹配r. More... | |
virtual int32_t | getNbOutputs () const noexcept=0 |
获取层的输出数量 More... | |
virtual Dims | getOutputDimensions (int32_t index, Dims const *inputs, int32_t nbInputDims) noexcept=0 |
获取输出张量的维度. More... | |
virtual bool | supportsFormat (DataType type, PluginFormat format) const noexcept=0 |
检查格式支持 More... | |
virtual int32_t | initialize () noexcept=0 |
I初始化要执行的层。这在引擎创建时调用More... | |
virtual void | terminate () noexcept=0 |
释放插件层初始化过程中获取的资源。这叫engine被毁 More... | |
virtual size_t | getWorkspaceSize (int32_t maxBatchSize) const noexcept=0 |
查找图层所需的工作空间大小. More... | |
virtual int32_t | enqueue (int32_t batchSize, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept=0 |
执行图层 More... | |
virtual size_t | getSerializationSize () const noexcept=0 |
查找所需的序列化缓冲区的大小 More... | |
virtual void | serialize (void *buffer) const noexcept=0 |
序列化图层r. More... | |
virtual void | destroy () noexcept=0 |
销毁插件对象。这将在网络、构建器或引擎被销毁时调用. More... | |
virtual void | setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0 |
设置此插件对象所属的命名空间。理想情况下,同一插件库中的所有插件对象都应该具有相同的命名空间. More... | |
virtual AsciiChar const * | getPluginNamespace () const noexcept=0 |
返回插件对象的命名空间. More... |
Protected Member Functions | |
int32_t | getTensorRTVersion () const noexcept override |
返回构建此插件的API版本。TensorRT保留的高位字节,用于区分此插件和IPluginV2IPluginV2. More... | |
void | configureWithFormat (Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override |
派生类不应该实现这个。在C++11 API中,它将被override final. More... |
IPluginCreator
用户实现层的插件创建器类。
virtual int32_t | getTensorRTVersion () const noexcept |
返回插件创建者编译时使用的API版本. More... | |
virtual AsciiChar const * | getPluginName () const noexcept=0 |
返回插件名称。. More... | |
virtual AsciiChar const * | getPluginVersion () const noexcept=0 |
返回插件版本. More... | |
virtual PluginFieldCollection const * | getFieldNames () noexcept=0 |
返回需要传递给createPlugin的字段列表. More... | |
virtual IPluginV2 * | createPlugin (AsciiChar const *name, PluginFieldCollection const *fc) noexcept=0 |
返回一个插件对象。错误时返回nullptr. More... | |
virtual IPluginV2 * | deserializePlugin (AsciiChar const *name, void const *serialData, size_t serialLength) noexcept=0 |
在插件层的反序列化过程中调用。返回插件对象 | |
virtual void | setPluginNamespace (AsciiChar const *pluginNamespace) noexcept=0 |
根据插件所属的插件库设置插件创建者的命名空间。这可以在注册插件创建者时设置。 More... | |
virtual AsciiChar const * | getPluginNamespace () const noexcept=0 |
返回插件创建者对象的命名空间。 More... | |
IPluginCreator ()=default | |
virtual | ~IPluginCreator ()=default |
Mish激活函数
#ifndef _MISH_PLUGIN_H
#define _MISH_PLUGIN_H
#include <string>
#include <vector>
#include "NvInfer.h"
namespace nvinfer1
{
class MishPlugin: public IPluginV2IOExt
{
public:
// 显示的构造函数
explicit MishPlugin();
// 构造函数
MishPlugin(const void* data, size_t length);
~MishPlugin();
// 返回plugin的输出数量
int getNbOutputs() const override
{
return 1;
}
// 输出张量的输出的维度
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override;
// 初始化要执行的层。这在引擎创建时调用
int initialize() override;
// 释放插件层初始化过程中获取的资源
virtual void terminate() override {};
// 获得该层工作空间的大小
virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
// 执行该的层的处理
virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
// 查找所需的序列化缓冲区的大小
virtual size_t getSerializationSize() const override;
// 进行对应的序列化
virtual void serialize(void* buffer) const override;
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const override {
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
}
const char* getPluginType() const override;
const char* getPluginVersion() const override;
// 销毁插件对象。这将在网络、构建器或引擎被销毁时调用
void destroy() override;
// 克隆对象
IPluginV2IOExt* clone() const override;
// 这是对象所属的命名空间
void setPluginNamespace(const char* pluginNamespace) override;
// 返回对象的命名空间
const char* getPluginNamespace() const override;
// 返回对象输出时间
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
// 是否进行舒畅张量在批处理中广播
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override;
// 如果插件可以使用跨批广播的输入而无需刻意的复制
bool canBroadcastInputAcrossBatch(int inputIndex) const override;
// 将插件对象附加到执行上下文,
void attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
// 使用输入和输出数据类型配置网络层
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;
// 从执行上下文中分离插件对象.
void detachFromContext() override;
int input_size_;
private:
void forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize = 1);
int thread_count_ = 256;
const char* mPluginNamespace;
};
class MishPluginCreator : public IPluginCreator
{
public:
// 构造函数
MishPluginCreator();
// 析构函数
~MishPluginCreator() override = default;
// 获得插件的名字
const char* getPluginName() const override;
// 获得插件的版本
const char* getPluginVersion() const override;
// 返回需要传递给createPlugin的字段列表
const PluginFieldCollection* getFieldNames() override;
// 返回一个对应的插件对象
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
// 在插件层的反序列化过程中调用。返回插件对象
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
// 根据插件所属的插件库设置插件创建者的命名空间
void setPluginNamespace(const char* libNamespace) override
{
mNamespace = libNamespace;
}
// 返回插件创建者对象的命名空间
const char* getPluginNamespace() const override
{
return mNamespace.c_str();
}
private:
std::string mNamespace;
static PluginFieldCollection mFC;
static std::vector<PluginField> mPluginAttributes;
};
REGISTER_TENSORRT_PLUGIN(MishPluginCreator);
};
#endif
#include <cmath>
#include <stdio.h>
#include <cassert>
#include <iostream>
#include "mish.h"
namespace nvinfer1
{
MishPlugin::MishPlugin()
{
}
MishPlugin::~MishPlugin()
{
}
// create the plugin at runtime from a byte stream
MishPlugin::MishPlugin(const void* data, size_t length)
{
assert(length == sizeof(input_size_));
input_size_ = *reinterpret_cast<const int*>(data);
}
void MishPlugin::serialize(void* buffer) const
{
*reinterpret_cast<int*>(buffer) = input_size_;
}
size_t MishPlugin::getSerializationSize() const
{
return sizeof(input_size_);
}
int MishPlugin::initialize()
{
return 0;
}
Dims MishPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
assert(nbInputDims == 1);
assert(index == 0);
input_size_ = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
// Output dimensions
return Dims3(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
}
// Set plugin namespace
void MishPlugin::setPluginNamespace(const char* pluginNamespace)
{
mPluginNamespace = pluginNamespace;
}
const char* MishPlugin::getPluginNamespace() const
{
return mPluginNamespace;
}
// Return the DataType of the plugin output at the requested index
DataType MishPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
return DataType::kFLOAT;
}
// Return true if output tensor is broadcast across a batch.
bool MishPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
{
return false;
}
// Return true if plugin can use input that is broadcast across batch without replication.
bool MishPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
{
return false;
}
void MishPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
{
}
// Attach the plugin object to an execution context and grant the plugin the access to some context resource.
void MishPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
{
}
// Detach the plugin object from its execution context.
void MishPlugin::detachFromContext() {}
const char* MishPlugin::getPluginType() const
{
return "Mish_TRT";
}
const char* MishPlugin::getPluginVersion() const
{
return "1";
}
void MishPlugin::destroy()
{
delete this;
}
// Clone the plugin
IPluginV2IOExt* MishPlugin::clone() const
{
MishPlugin *p = new MishPlugin();
p->input_size_ = input_size_;
p->setPluginNamespace(mPluginNamespace);
return p;
}
__device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
__device__ float softplus_kernel(float x, float threshold = 20) {
if (x > threshold) return x; // too large
else if (x < -threshold) return expf(x); // too small
return logf(expf(x) + 1);
}
__global__ void mish_kernel(const float *input, float *output, int num_elem) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= num_elem) return;
//float t = exp(input[idx]);
//if (input[idx] > 20.0) {
// t *= t;
// output[idx] = (t - 1.0) / (t + 1.0);
//} else {
// float tt = t * t;
// output[idx] = (tt + 2.0 * t) / (tt + 2.0 * t + 2.0);
//}
//output[idx] *= input[idx];
output[idx] = input[idx] * tanh_activate_kernel(softplus_kernel(input[idx]));
}
void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
int block_size = thread_count_;
int grid_size = (input_size_ * batchSize + block_size - 1) / block_size;
mish_kernel<<<grid_size, block_size>>>(inputs[0], output, input_size_ * batchSize);
}
int MishPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
{
//assert(batchSize == 1);
//GPU
//CUDA_CHECK(cudaStreamSynchronize(stream));
forwardGpu((const float *const *)inputs, (float*)outputs[0], stream, batchSize);
return 0;
}
PluginFieldCollection MishPluginCreator::mFC{};
std::vector<PluginField> MishPluginCreator::mPluginAttributes;
MishPluginCreator::MishPluginCreator()
{
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char* MishPluginCreator::getPluginName() const
{
return "Mish_TRT";
}
const char* MishPluginCreator::getPluginVersion() const
{
return "1";
}
const PluginFieldCollection* MishPluginCreator::getFieldNames()
{
return &mFC;
}
IPluginV2IOExt* MishPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
MishPlugin* obj = new MishPlugin();
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
IPluginV2IOExt* MishPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
// This object will be deleted when the network is destroyed, which will
// call MishPlugin::destroy()
MishPlugin* obj = new MishPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
}
将该对象添加到模型中
Weights emptywts{DataType::kFLOAT, nullptr, 0};
//卷积层处理
//!
//! \brief Add a multi-dimension convolution layer to the network.
//!
//! \param The ininputput tensor to the convolution.
//! \param nbOutputMaps The number of output feature maps for the convolution.
//! \param kernelSize The multi-dimensions of the convolution kernel.
//! \param kernelWeights The kernel weights for the convolution.
//! \param biasWeights The optional bias weights for the convolution.
// IConvolutionLayer* addConvolutionNd(
// ITensor& input, int32_t nbOutputMaps, Dims kernelSize, Weights kernelWeights, Weights biasWeights) noexcept
// {
// return mImpl->addConvolutionNd(input, nbOutputMaps, kernelSize, kernelWeights, biasWeights);
// }
IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap["module_list." + std::to_string(linx) + ".Conv2d.weight"], emptywts);
assert(conv1);
// 设置对应参数
conv1->setStrideNd(DimsHW{s, s});
conv1->setPaddingNd(DimsHW{p, p});
// 设置对应的批量归一化数据
IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "module_list." + std::to_string(linx) + ".BatchNorm2d", 1e-4);
// 创建对应的mish激活函数
auto creator = getPluginRegistry()->getPluginCreator("Mish_TRT", "1");
const PluginFieldCollection* pluginData = creator->getFieldNames();
IPluginV2 *pluginObj = creator->createPlugin(("mish" + std::to_string(linx)).c_str(), pluginData);
ITensor* inputTensors[] = {bn1->getOutput(0)};
auto mish = network->addPluginV2(&inputTensors[0], 1, *pluginObj);
return mish;