tensorrt7.0 的文档路径:https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/index.html
tensorrt plugin编写基础类中函数说明
- 三个重构函数,其中三个重构函数分别的作用为:
- 其中两个给creator类使用;
- 第三个给基础类中的clone函数使用;
- getNbOutputs: 该层返回输出的张量个数,
- getOutputDimensions:返回输出的张量维度(返回多个张量咋写? 估计会根据index返回不同的Dims结构),
- configureWithFormat :根据数据个数做出一些调整,反正会传入一个DataType参数,看程序是否需要做一些调整,如果是要实现INT8,datatype在这里配置。
- initialize: 做一些初始化的工作,有些工作放到了析构函数中做
- terminate: 跟initialize 对应,有些工作放到了构造函数中做
- serialize: 把一些用得着的数据写入一个buffer
- getSerializationSize: 返回serialize时写了多少个字节到buffer了
- getWorkSpaceSize: 都是直接返回0了
- enqueue: 根据输入数据进行计算处理得到输出,一般用刀了GPU 加速,所以这个函数会在 .cu文件中实现
- 另外Plugin类的实现中还需要两个构造函数和一个析构函数
- 两个构造函数分别完成的工作是 1、通过API调用或者网络结构构造类,函数参数自定义,因为在pluginfactory中自主调用构造。2 通过反序列化完成参数读入进行的plugin构造.对应构造函数plugin(const void* data,int length)
- clone:This function is called by the Builder prior to initialize() . It provides an opportunity for the layer to make algorithm choices on the basis of its weights, dimensions, and maximum batch size.
creator类说明
- creator类是用于与tensorrt直接调用相关的接口类;
- 其会有一些基础函数,包括:
- getPluginName:获取该plugin的名字
- getPluginVersion: 获取该plugin的版本号
- getFieldNames:用于获取layer中的参数;
- createPlugin:标准接口,会由tensorrt内部进行调用,因此格式固定(const char* name, const PluginFieldCollection* fc);
- deserializePlugin :用于反序列化,格式固定(const char* name, const void* serialData, size_t serialLength);
调用流程
- 初始化***plugincreator,获取pluginname 和 pluginversion;
- 启动 plugincreator的constructor函数,并调用plugin的constructor函数,将参数设置如plugin中;
- 调用clone函数,对plugin进行clone;
- 调用supportFormat函数(确定tensorrt计算的type–kFLOAT和通道的循序–kNCHW)-> 再次调用clone函数 -> plugin constructor -> configurePlugin(用于做参数判定,如果配置的参数不满足要求则直接报错) -> clone -> destroy -> initialize -> destroy;
- 序列化过程:getType -> Namespace -> serialization size -> serialize -> terminate -> destroy;
- 反序列化: plugincreator constructor -> creator deserialize -> plugin deserialize -> initialize -> enqueue -> terminate -> destroy。
custom plugin layer(with python)
https://www.cnblogs.com/shouhuxianjian/p/10532950.html
attention ocr介绍
文章对attention ocr的全流程进行了讲解,内容非常好,尤其是对与网络的介绍。 整体流程为:encoder+decoder ----------encoder采用CNN+biLSTM模型 ------- decoder采用Attention模型
attention ocr最终回归出来的结果是69(69个字符)1728(2472)
其中包含了lstm,但是lstm在tensorrt的7.0以下版本都没有实现,所以最好是转onnx然后做trt的实现。
plugin编写流程以及注意事项说明(基于tensorflow,pb转uff),本案例以rensorrt source code中batchedNMSPlugin进行介绍
- pb转uff时,在create_plugin_note的时候可以设置参数的输入;
此参数可以通过“plugincreator”接口进行设置并调用:
mPluginAttributes.emplace_back(PluginField("shareLocation", nullptr, PluginFieldType::kINT32, 1));
const PluginField* fields = fc->fields;
mClipBoxes = true;
for (int i = 0; i < fc->nbFields; ++i)
{
const char* attrName = fields[i].name;
if (!strcmp(attrName, "shareLocation"))
{
params.shareLocation = *(static_cast<const bool*>(fields[i].data));
}
}
- 并通过plgin的构造函数将参数传递到plugin函数。
BatchedNMSPlugin* plugin = new BatchedNMSPlugin(params);
- plugin函数编写(需注意以下几个函数的编写)。
a. 多个构造函数,其中一个给clone调用,一个给plugincreator调用并传参,还有一个用于序列化,用于反序列化的函数需要注意,需要将需传递到cu代码参数进行read,有多少个参数就read多少个参数;
BatchedNMSPlugin::BatchedNMSPlugin(const void* data, size_t length)
{
const char *d = reinterpret_cast<const char*>(data), *a = d;
param = read<NMSParameters>(d);
boxesSize = read<int>(d);
scoresSize = read<int>(d);
numPriors = read<int>(d);
mClipBoxes = read<bool>(d);
ASSERT(d == a + length);
}
b. getNbOutputs需要返回编写的plugin的输出数量,这个需要预先知道并填写;
int BatchedNMSPlugin::getNbOutputs() const
{
return 4;
}
c. getOutputDimensions函数,可以通过该函数获取输入的参数数量以及每个参数的维度信息
inputs[0].d[0]为输入1维度1的信息;
inputs[1].d[0]为输入2维度1的信息;
inputs[0].nbDims可以获取输入一有多少维;
nbInputDims为输入数量;
此处也可以获取一些全局变量,例如输入数据的维度,可以将其作为全局变量,以备enqueue函数使用。
此处需要预先知道输出的维度是多少并返回,如果只有2为可以使用return DimsHW( ,);输出三维可以使用return DimsCHW( ,,);等
d. enqueue函数是trt inference的主要入口,其中调用的函数需要编写cu文件来实现或者直接通过cuda的标准接口实现。
例如:
pluginStatus_t status = nmsInference(stream, batchSize, boxesSize, scoresSize, param.shareLocation,
param.backgroundLabelId, numPriors, param.numClasses, param.topK, param.keepTopK, param.scoreThreshold,
param.iouThreshold, DataType::kFLOAT, locData, DataType::kFLOAT, confData, keepCount, nmsedBoxes, nmsedScores,
nmsedClasses, workspace, param.isNormalized, false, mClipBoxes);
该接口需要在plugin/common/kernal.h里面进行定义,并在plugin/common/kernal/底下编写cu code。
e. getSerializationSize用来说明你在序列化写的数据长度或者是你在反序列化读的长度(是一致的);
f. serialize用于序列化,需要写参数。
void BatchedNMSPlugin::serialize(void* buffer) const
{
char *d = reinterpret_cast<char*>(buffer), *a = d;
write(d, param);
write(d, boxesSize);
write(d, scoresSize);
write(d, numPriors);
write(d, mClipBoxes);
ASSERT(d == a + getSerializationSize());
}
g. configurePlugin和getOutputDimensions的输入数据基本类似,同样可以获取输入参数的维度信息,并设置全局变量;
这个函数相对于getOutputDimensions而言,就是这个函数是必须要实现的且全局变量必须赋予值。
里面也可以对输入数据进行维度的报错。
h.clone函数需要调用plugin的构造并传递全局参数。
以上几个函数是必须要实现的,其他的函数参考示例就行。