Tensorrt plugin编写思路

tensorrt7.0 的文档路径:https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/index.html

tensorrt plugin编写基础类中函数说明

  • 三个重构函数,其中三个重构函数分别的作用为:
    • 其中两个给creator类使用;
    • 第三个给基础类中的clone函数使用;
  • getNbOutputs: 该层返回输出的张量个数,
  • getOutputDimensions:返回输出的张量维度(返回多个张量咋写? 估计会根据index返回不同的Dims结构),
  • configureWithFormat :根据数据个数做出一些调整,反正会传入一个DataType参数,看程序是否需要做一些调整,如果是要实现INT8,datatype在这里配置。
  • initialize: 做一些初始化的工作,有些工作放到了析构函数中做
  • terminate: 跟initialize 对应,有些工作放到了构造函数中做
  • serialize: 把一些用得着的数据写入一个buffer
  • getSerializationSize: 返回serialize时写了多少个字节到buffer了
  • getWorkSpaceSize: 都是直接返回0了
  • enqueue: 根据输入数据进行计算处理得到输出,一般用刀了GPU 加速,所以这个函数会在 .cu文件中实现
  • 另外Plugin类的实现中还需要两个构造函数和一个析构函数
  • 两个构造函数分别完成的工作是 1、通过API调用或者网络结构构造类,函数参数自定义,因为在pluginfactory中自主调用构造。2 通过反序列化完成参数读入进行的plugin构造.对应构造函数plugin(const void* data,int length)
  • clone:This function is called by the Builder prior to initialize() . It provides an opportunity for the layer to make algorithm choices on the basis of its weights, dimensions, and maximum batch size.

creator类说明

  • creator类是用于与tensorrt直接调用相关的接口类;
  • 其会有一些基础函数,包括:
    • getPluginName:获取该plugin的名字
    • getPluginVersion: 获取该plugin的版本号
    • getFieldNames:用于获取layer中的参数;
    • createPlugin:标准接口,会由tensorrt内部进行调用,因此格式固定(const char* name, const PluginFieldCollection* fc);
    • deserializePlugin :用于反序列化,格式固定(const char* name, const void* serialData, size_t serialLength);

调用流程

  1. 初始化***plugincreator,获取pluginname 和 pluginversion;
  2. 启动 plugincreator的constructor函数,并调用plugin的constructor函数,将参数设置如plugin中;
  3. 调用clone函数,对plugin进行clone;
  4. 调用supportFormat函数(确定tensorrt计算的type–kFLOAT和通道的循序–kNCHW)-> 再次调用clone函数 -> plugin constructor -> configurePlugin(用于做参数判定,如果配置的参数不满足要求则直接报错) -> clone -> destroy -> initialize -> destroy;
  5. 序列化过程:getType -> Namespace -> serialization size -> serialize -> terminate -> destroy;
  6. 反序列化: plugincreator constructor -> creator deserialize -> plugin deserialize -> initialize -> enqueue -> terminate -> destroy。

custom plugin layer(with python)

https://www.cnblogs.com/shouhuxianjian/p/10532950.html

attention ocr介绍

文章对attention ocr的全流程进行了讲解,内容非常好,尤其是对与网络的介绍。 整体流程为:encoder+decoder ----------encoder采用CNN+biLSTM模型 ------- decoder采用Attention模型
attention ocr最终回归出来的结果是69(69个字符)1728(2472)
其中包含了lstm,但是lstm在tensorrt的7.0以下版本都没有实现,所以最好是转onnx然后做trt的实现。

plugin编写流程以及注意事项说明(基于tensorflow,pb转uff),本案例以rensorrt source code中batchedNMSPlugin进行介绍

  1. pb转uff时,在create_plugin_note的时候可以设置参数的输入;
    此参数可以通过“plugincreator”接口进行设置并调用:
mPluginAttributes.emplace_back(PluginField("shareLocation", nullptr, PluginFieldType::kINT32, 1));
    const PluginField* fields = fc->fields;
    mClipBoxes = true;

    for (int i = 0; i < fc->nbFields; ++i)
    {
        const char* attrName = fields[i].name;
        if (!strcmp(attrName, "shareLocation"))
        {
            params.shareLocation = *(static_cast<const bool*>(fields[i].data));
        }
    }
  1. 并通过plgin的构造函数将参数传递到plugin函数。
BatchedNMSPlugin* plugin = new BatchedNMSPlugin(params);
  1. plugin函数编写(需注意以下几个函数的编写)。
    a. 多个构造函数,其中一个给clone调用,一个给plugincreator调用并传参,还有一个用于序列化,用于反序列化的函数需要注意,需要将需传递到cu代码参数进行read,有多少个参数就read多少个参数;
BatchedNMSPlugin::BatchedNMSPlugin(const void* data, size_t length)
{
    const char *d = reinterpret_cast<const char*>(data), *a = d;
    param = read<NMSParameters>(d);
    boxesSize = read<int>(d);
    scoresSize = read<int>(d);
    numPriors = read<int>(d);
    mClipBoxes = read<bool>(d);
    ASSERT(d == a + length);
}

b. getNbOutputs需要返回编写的plugin的输出数量,这个需要预先知道并填写;

int BatchedNMSPlugin::getNbOutputs() const
{
    return 4;
}

c. getOutputDimensions函数,可以通过该函数获取输入的参数数量以及每个参数的维度信息

inputs[0].d[0]为输入1维度1的信息;
inputs[1].d[0]为输入2维度1的信息;
inputs[0].nbDims可以获取输入一有多少维;
nbInputDims为输入数量;

此处也可以获取一些全局变量,例如输入数据的维度,可以将其作为全局变量,以备enqueue函数使用。
此处需要预先知道输出的维度是多少并返回,如果只有2为可以使用return DimsHW( ,);输出三维可以使用return DimsCHW( ,,);等
d. enqueue函数是trt inference的主要入口,其中调用的函数需要编写cu文件来实现或者直接通过cuda的标准接口实现。
例如:

    pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation,
        param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
        param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
        nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);

该接口需要在plugin/common/kernal.h里面进行定义,并在plugin/common/kernal/底下编写cu code。
e. getSerializationSize用来说明你在序列化写的数据长度或者是你在反序列化读的长度(是一致的);
f. serialize用于序列化,需要写参数。

void BatchedNMSPlugin::serialize(void* buffer) const
{
    char *d = reinterpret_cast<char*>(buffer), *a = d;
    write(d, param);
    write(d, boxesSize);
    write(d, scoresSize);
    write(d, numPriors);
    write(d, mClipBoxes);
    ASSERT(d == a + getSerializationSize());
}

g. configurePlugin和getOutputDimensions的输入数据基本类似,同样可以获取输入参数的维度信息,并设置全局变量;
这个函数相对于getOutputDimensions而言,就是这个函数是必须要实现的且全局变量必须赋予值。
里面也可以对输入数据进行维度的报错。
h.clone函数需要调用plugin的构造并传递全局参数。
以上几个函数是必须要实现的,其他的函数参考示例就行。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值