1. 前言
上一节对TensorRT做了介绍,然后科普了TensorRT优化方式以及讲解在Windows下如何安装TensorRT6.0,最后还介绍了如何编译一个官方给出的手写数字识别例子获得一个正确的预测结果。这一节我将结合TensorRT官方给出的一个例程来介绍TensorRT的使用,这个例程是使用LeNet完成MNIST手写数字识别,例程所在的目录为:
2. 代码解析
按照上一节的讲解,我们知道TensorRT的例程主要是分为Build和Deployment(infer)这两个步骤,接下来我们就按照参数初始化,Build,Deployment这个顺序来看看代码。
2.1 主函数
sampleMNIST例程的主函数代码实现如下:
int main(int argc, char** argv)
{
// 参数解析
samplesCommon::Args args;
bool argsOK = samplesCommon::parseArgs(args, argc, argv);
if (!argsOK)
{
gLogError << "Invalid arguments" << std::endl;
printHelpInfo();
return EXIT_FAILURE;
}
// 打印帮助信息
if (args.help)
{
printHelpInfo();
return EXIT_SUCCESS;
}
auto sampleTest = gLogger.defineTest(gSampleName, argc, argv);
gLogger.reportTestStart(sampleTest);
// 使用命令行参数初始化params结构的成员
samplesCommon::CaffeSampleParams params = initializeSampleParams(args);
// 构造SampleMNIST对象
SampleMNIST sample(params);
gLogInfo << "Building and running a GPU inference engine for MNIST" << std::endl;
// Build 此函数通过解析caffe模型创建MNIST网络,并构建用于运行MNIST(mEngine)的引擎
if (!sample.build())
{
return gLogger.reportFail(sampleTest);
}
// 前向推理如果没成功,用gLogger报告状态
if (!sample.infer())
{
return gLogger.reportFail(sampleTest);
}
// 用于清除示例类中创建的任何状态,内存释放
if (!sample.teardown())
{
return gLogger.reportFail(sampleTest);
}
// 报告例子运行成功
return gLogger.reportPass(sampleTest);
}
可以清晰的看到代码主要分为参数初始化,Build,Infer这三大部分,最后的输出结果是下面这样。
2.2 参数初始化
参数初始化主要由initializeSampleParams函数来完成,这个函数的详细注释如下,具体就是根据输入数据和网络文件所在的文件夹去读取LeNet的Caffe原始模型文件和均值文件,另外设置一些如输出Tensor名字,batch大小,运行时精度模式等关键参数,最后返回一个params对象。注意这里使用的LeNet模型是Caffe的原始模型,因为TensorRT是直接支持Caffe的原始模型解析的,但例如Pytorch模型之类的还要进行转换,这在以后的文章中会涉及到。
//!
//! 简介: 使用命令行参数初始化params结构的成员
//!
samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
samplesCommon::CaffeSampleParams params;
if (args.dataDirs.empty()) //!< 如果用户未提供目录路径,则使用默认目录
{
params.dataDirs.push_back("data/mnist/");
params.dataDirs.push_back("data/samples/mnist/");
}
else //!< 使用用户提供的目录路径
{
params.dataDirs = args.dataDirs;
}
params.prototxtFileName = locateFile("mnist.prototxt", params.dataDirs); //读取params.dataDirs文件夹下的mnist.prototxt
params.weightsFileName = locateFile("mnist.caffemodel", params.dataDirs); //读取params.dataDirs文件夹下的mnist.caffemodel
params.meanFileName = locateFile("mnist_mean.binarypr