1、环境准备
我是在 Jetson Nano 上运行的，版本信息可以在终端执行 `sudo jtop` 后按 5 查看（本文环境为 CUDA 10.0，ONNX 解析器使用 onnx-tensorrt 的 release-6.0 分支，见下方 CMakeLists）。
JetPack 自带了 TensorRT，样例代码位于 `/usr/src/tensorrt/samples`。
由于加载 ONNX 模型还需要额外安装 onnx-tensorrt 解析器（nvonnxparser），安装步骤可参考：
《Jetson Xavier 上 TensorRT 环境安装》（www.zybuluo.com）
只需看其中的第四点即可。
2、目录结构
.
├── build
├── CMakeLists.txt -------------①
├── src
│ ├── CMakeLists.txt ----------②
│ └── main.cpp
└── weights
└── yolov3-mytiny_98_0.96_warehouse.onnx
3、代码
- 1）顶层 CMakeLists.txt（对应目录结构中的 ①）
# Top-level build file: declares the project and delegates the actual
# executable target to src/CMakeLists.txt.
# cmake_minimum_required must precede project(); without it CMake warns and
# the policy settings used for the build are unspecified.
cmake_minimum_required(VERSION 3.5)
PROJECT(TRT)
ADD_SUBDIRECTORY(src)
- 2）src/CMakeLists.txt（对应 ②，写得比较粗糙；其中的 `your_path` 需要替换为 onnx-tensorrt 源码的实际路径）
# Builds the trt_test executable from main.cpp plus the TensorRT samples'
# logger, and links it against the ONNX parser, TensorRT and the CUDA runtime.
# The default paths target an aarch64 Jetson with CUDA 10.0 and an
# onnx-tensorrt release-6.0 checkout; override them with -D<VAR>=... .

# Helper headers (logger.h, common.h, buffers.h, ...) shipped with the TensorRT samples.
SET(common_dir /usr/src/tensorrt/samples/common)
INCLUDE_DIRECTORIES(${common_dir})

# CUDA toolkit root; both the headers and libcudart are located relative to it,
# so changing the CUDA version only requires editing this one value.
SET(CUDA_ROOT /usr/local/cuda-10.0 CACHE PATH "CUDA toolkit root directory")
SET(cuda_dir ${CUDA_ROOT}/targets/aarch64-linux/include)
INCLUDE_DIRECTORIES(${cuda_dir})

# onnx-tensorrt source tree: headers at its root, libnvonnxparser in its build/ dir.
# Replace the placeholder here or pass -DONNX_PARSE=/path/to/onnx-tensorrt-release-6.0.
SET(ONNX_PARSE your_path/onnx-tensorrt-release-6.0 CACHE PATH "onnx-tensorrt source directory")
INCLUDE_DIRECTORIES(${ONNX_PARSE})

# The samples' logger implementation is compiled in so main.cpp can use its logger.
SET(LOG_CPP ${common_dir}/logger.cpp)
ADD_EXECUTABLE(trt_test main.cpp ${LOG_CPP})

find_library(LIBONNX_PATH nvonnxparser ${ONNX_PARSE}/build)
find_library(LIBNVINFER nvinfer /usr/lib/aarch64-linux-gnu/)
find_library(LIBCUDART cudart ${CUDA_ROOT}/lib64/)

# Stop at configure time with a readable message instead of a NOTFOUND/link error later.
foreach(lib LIBONNX_PATH LIBNVINFER LIBCUDART)
  if(NOT ${lib})
    message(FATAL_ERROR "${lib} not found; check the library paths in this CMakeLists.txt")
  endif()
endforeach()

TARGET_LINK_LIBRARIES(trt_test ${LIBONNX_PATH} ${LIBNVINFER} ${LIBCUDART})
- 3)main.cpp
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "BatchStream.h"
#include "EntropyCalibrator.h"
#include "NvOnnxParser.h"
#include "NvInfer.h"
using namespace nvinfer1;
using namespace nvonnxparser;
int main() {
samplesCommon::Args args;
// 1 加载onnx模型
IBuilder* builder = createInferBuilder(gLogger);
nvinfer1::INetworkDefinition* network = builder->createNetwork();
nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
const char* onnx_filename="./weights/yolov3-mytiny_98_0.96_warehouse.onnx";
parser->parseFromFile(onnx_filename, static_cast<int>(Logger::Severity::kWARNING));
for (int i = 0; i < parser->getNbErrors(); ++i)
{
std::cout << parser->getError(i)->desc() << std::endl;
}
std::cout << "successfully load the onnx model" << std::endl;
// 2、build the engine
unsigned int maxBatchSize=1;
builder->setMaxBatchSize(maxBatchSize);
IBuilderConfig* config = builder->createBuilderConfig();
config->setMaxWorkspaceSize(1 << 20);
ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
// 3、serialize Model
IHostMemory *gieModelStream = engine->serialize();
std::string serialize_str;
std::ofstream serialize_output_stream;
serialize_str.resize(gieModelStream->size());
memcpy((void*)serialize_str.data(),gieModelStream->data(),gieModelStream->size());
serialize_output_stream.open("./serialize_engine_output.trt");
serialize_output_stream<<serialize_str;
serialize_output_stream.close();
// 4、deserialize model
IRuntime* runtime = createInferRuntime(gLogger);
std::string cached_path = "./build/serialize_engine_output.trt";
std::ifstream fin(cached_path);
std::string cached_engine = "";
while (fin.peek() != EOF){
std::stringstream buffer;
buffer << fin.rdbuf();
cached_engine.append(buffer.str());
}
fin.close();
ICudaEngine* re_engine = runtime->deserializeCudaEngine(cached_engine.data(), cached_engine.size(), nullptr);
std::cout << "Hello, World!" << std::endl;
return 0;
}
4、异常
- 1)nvinfer1::INetworkDefinition* network = builder->createNetwork();
// The official way to create the network (with the explicit-batch flag); it produces the error shown below
const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
// The way used in other people's code (no explicit-batch flag); it does not produce the error
nvinfer1::INetworkDefinition* network = builder->createNetwork();
使用 createNetworkV2（kEXPLICIT_BATCH）时的错误信息：
ERROR: /home/webank/tqiu/env/onnx-tensorrt-release-6.0/builtin_op_importers.cpp:661 In function importBatchNormalization:
[6] Assertion failed: scale_weights.shape == weights_shape
Assertion failed: scale_weights.shape == weights_shape
successfully load the onnx model
[06/06/2020-19:17:23] [E] [TRT] Network must have at least one output
- 2)parser->parseFromFile(onnx_filename, ILogger::Severity::kWARNING);
// The official form; it fails to compile with the error below
parser->parseFromFile(onnx_filename, ILogger::Severity::kWARNING);
// Working form: the error below shows there is no overload taking ILogger::Severity,
// so the severity is cast to int explicitly
parser->parseFromFile(onnx_filename,static_cast<int>(Logger::Severity::kWARNING));
官网写法的编译错误信息（该版本 parseFromFile 的第二个参数为 int，ILogger::Severity 枚举不会隐式转换，因此需要 static_cast<int>）：
error: no matching function for call to ‘nvonnxparser::IParser::parseFromFile(const char*&, nvinfer1::ILogger::Severity)’
parser->parseFromFile(onnx_filename, ILogger::Severity::kWARNING);