0. Preface
- This follows on from the earlier post on the sampleMNISTAPI sample.
- A small question: there seem to be two ways to use an ONNX model with TensorRT. One, as in this sample, is to convert the ONNX model directly inside the program; the other is to first convert the ONNX model into an engine file with the official tooling (e.g. trtexec). I'm not sure what the difference between the two is.
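- For reference, the second route would look roughly like the sketch below: build the engine once offline (e.g. `trtexec --onnx=mnist.onnx --saveEngine=mnist.engine`) and only deserialize it at startup. `loadEngine`, the file name, and the error handling are my own illustration, not code from the sample:

```cpp
#include <fstream>
#include <string>
#include <vector>
#include "NvInfer.h"

// Sketch: load an engine file produced offline (e.g. by trtexec).
// No ONNX parsing or optimization happens here, only deserialization.
nvinfer1::ICudaEngine* loadEngine(const std::string& path, nvinfer1::ILogger& logger)
{
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file)
    {
        return nullptr;
    }
    const auto size = file.tellg();
    file.seekg(0, std::ios::beg);
    std::vector<char> blob(static_cast<size_t>(size));
    file.read(blob.data(), size);

    // The runtime deserializes the already-optimized engine.
    // (Intentionally leaked in this sketch; it must outlive the engine.)
    auto* runtime = nvinfer1::createInferRuntime(logger);
    return runtime->deserializeCudaEngine(blob.data(), blob.size());
}
```

- As far as I can tell, both routes end up with the same ICudaEngine; the main difference is when the slow builder/optimization step runs: inside your program (as in this sample) or once ahead of time.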
1. ONNX Model Conversion
- I won't go through the rest of the code; let's look closely at the SampleOnnxMNIST::build() function.
1.1. The build() Function in Detail
- Building the network basically follows this flow, as in the listing below:
  - Create the builder
  - Create an empty network object
  - Create the builder config
  - Create the ONNX parser
  - Parse the model and store its structure in the network object via the parser
  - Set some model options (e.g. quantization)
  - Validate the result
```cpp
bool SampleOnnxMNIST::build()
{
    // Create the builder
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }
    // Create an empty network object; the ONNX parser requires explicit-batch mode
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        return false;
    }
    // Create the builder config, which carries build-time options
    // (workspace size, precision flags, DLA settings, ...)
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }
    // Create the ONNX parser, bound to the network object
    auto parser
        = SampleUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }
    // Parse the model file and populate the network; also sets the build options
    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }
    // Build the optimized engine from the populated network and the config
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }
    // Validate the result: one 4-D input and one 2-D output
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 4);

    assert(network->getNbOutputs() == 1);
    mOutputDims = network->getOutput(0)->getDimensions();
    assert(mOutputDims.nbDims == 2);

    return true;
}
```
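- This connects back to the question in the preface: the engine built here could be serialized to disk and reloaded later, which is essentially what the offline-tool route produces. A minimal sketch (`saveEngine` and the output file name are my own, not part of the sample; SampleUniquePtr is the sample's unique_ptr alias):

```cpp
#include <fstream>
#include <string>
#include "NvInfer.h"

// Sketch: persist the engine built by build() so later runs can skip
// parsing and optimization (similar to what `trtexec --saveEngine` emits).
void saveEngine(nvinfer1::ICudaEngine& engine, const std::string& path)
{
    auto serialized = SampleUniquePtr<nvinfer1::IHostMemory>(engine.serialize());
    std::ofstream out(path, std::ios::binary);
    out.write(static_cast<const char*>(serialized->data()),
        static_cast<std::streamsize>(serialized->size()));
}
```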
- The core of the previous step is constructNetwork, which uses the parser to parse the model and store the result in the network object.
```cpp
//!
//! \brief Uses an ONNX parser to create the ONNX MNIST network and marks the
//!        output layers
//!
//! \param network Pointer to the network that will be populated with the ONNX MNIST network
//!
//! \param builder Pointer to the engine builder
//!
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    // Note: the network object was already passed in when the parser was
    // created, so parseFromFile() writes the parsed model straight into it
    auto parsed = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }
    // Precision settings (quantization); I'm not sure how this differs from
    // the onnx_tensorrt tool
    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }
    // DLA here stands for Deep Learning Accelerator
    // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dla_layers
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
    return true;
}
```
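- The two samplesCommon helpers at the end were the parts I found least obvious, so here is a rough idea of what they do. This is a simplified sketch based on my reading, not the actual implementations: setAllTensorScales gives every tensor a fake dynamic range so INT8 mode works without a real calibration dataset, and enableDLA routes the network to a DLA core with GPU fallback for unsupported layers.

```cpp
#include "NvInfer.h"

// Simplified sketch of setAllTensorScales: INT8 mode needs a dynamic range
// for every tensor; without calibration data the sample just fakes one.
// (The real helper also covers network inputs; omitted here for brevity.)
void setAllTensorScalesSketch(nvinfer1::INetworkDefinition* network, float range = 127.0f)
{
    for (int i = 0; i < network->getNbLayers(); ++i)
    {
        auto* layer = network->getLayer(i);
        for (int j = 0; j < layer->getNbOutputs(); ++j)
        {
            auto* tensor = layer->getOutput(j);
            if (!tensor->dynamicRangeIsSet())
            {
                tensor->setDynamicRange(-range, range); // fake [-127, 127] range
            }
        }
    }
}

// Simplified sketch of enableDLA: prefer the given DLA core and fall back
// to the GPU for layers DLA cannot run.
void enableDLASketch(nvinfer1::IBuilderConfig* config, int dlaCore)
{
    if (dlaCore >= 0)
    {
        config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
        config->setDLACore(dlaCore);
        config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
    }
}
```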