mnist tensorrt 运行_深度学习算法优化系列十八 | TensorRT Mnist数字识别使用示例

1. 前言

上一节对TensorRT做了介绍,然后科普了TensorRT优化方式以及讲解在Windows下如何安装TensorRT6.0,最后还介绍了如何编译一个官方给出的手写数字识别例子获得一个正确的预测结果。这一节我将结合TensorRT官方给出的一个例程来介绍TensorRT的使用,这个例程是使用LeNet完成MNIST手写数字识别,例程所在的目录为:

2. 代码解析

按照上一节的讲解,我们知道TensorRT的例程主要是分为Build和Deployment(infer)这两个步骤,接下来我们就按照参数初始化,Build,Deployment这个顺序来看看代码。

2.1 主函数

sampleMNIST例程的主函数代码实现如下:

int main(int argc, char** argv)

{

// 参数解析

samplesCommon::Args args;

bool argsOK = samplesCommon::parseArgs(args, argc, argv);

if (!argsOK)

{

gLogError << "Invalid arguments" << std::endl;

printHelpInfo();

return EXIT_FAILURE;

}

// 打印帮助信息

if (args.help)

{

printHelpInfo();

return EXIT_SUCCESS;

}

auto sampleTest = gLogger.defineTest(gSampleName, argc, argv);

gLogger.reportTestStart(sampleTest);

// 使用命令行参数初始化params结构的成员

samplesCommon::CaffeSampleParams params = initializeSampleParams(args);

// 构造SampleMNIST对象

SampleMNIST sample(params);

gLogInfo << "Building and running a GPU inference engine for MNIST" << std::endl;

// Build 此函数通过解析caffe模型创建MNIST网络,并构建用于运行MNIST(mEngine)的引擎

if (!sample.build())

{

return gLogger.reportFail(sampleTest);

}

// 前向推理如果没成功,用gLogger报告状态

if (!sample.infer())

{

return gLogger.reportFail(sampleTest);

}

// 用于清除示例类中创建的任何状态,内存释放

if (!sample.teardown())

{

return gLogger.reportFail(sampleTest);

}

// 报告例子运行成功

return gLogger.reportPass(sampleTest);

}

可以清晰的看到代码主要分为参数初始化,Build,Infer这三大部分,最后的输出结果是下面这样。

2.2 参数初始化

参数初始化主要由initializeSampleParams函数来完成,这个函数的详细注释如下,具体就是根据输入数据和网络文件所在的文件夹去读取LeNet的Caffe原始模型文件和均值文件,另外设置一些如输出Tensor名字,batch大小,运行时精度模式等关键参数,最后返回一个params对象。注意这里使用的LeNet模型是Caffe的原始模型,因为TensorRT是直接支持Caffe的原始模型解析的,但例如Pytorch模型之类的还要进行转换,这在以后的文章中会涉及到。

//!

//! 简介: 使用命令行参数初始化params结构的成员

//!

samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)

{

samplesCommon::CaffeSampleParams params;

if (args.dataDirs.empty()) //!< 如果用户未提供目录路径,则使用默认目录

{

params.dataDirs.push_back("data/mnist/");

params.dataDirs.push_back("data/samples/mnist/");

}

else //!< 使用用户提供的目录路径

{

params.dataDirs = args.dataDirs;

}

params.prototxtFileName = locateFile("mnist.prototxt", params.dataDirs); //读取params.dataDirs文件夹下的mnist.prototxt

params.weightsFileName = locateFile("mnist.caffemodel", params.dataDirs); //读取params.dataDirs文件夹下的mnist.caffemodel

params.meanFileName = locateFile("mnist_mean.binarypr

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值