TensorRT教程6:使用 C++ API 从头创建网络(重点)

使用 C++ API 从头创建网络(重点)

这种方法是兼容性最强,效率最高,也是难度最高的。

==那直接使用底层API有什么好处呢?看下表 ,可以看到对于 RNN,int8校准,不对称 padding 来说,NvCaffeParser是不支持的,只有 C++ API 和 Python API,才是支持的。所以说如果是针对很复杂的网络结构使用tensorRT,还是直接使用底层的 C++ API,和Python API 较好。==底层C++ API还可以解析像 darknet 这样的网络模型,因为它需要的就只是一个层名和权值参数对应的map文件。

FeatureC++PythonNvCaffeParserNvUffParser
CNNsyesyesyesyes
RNNsyesyesnono
INT8 CalibrationyesyesNANA
Asymmetric Paddingyesyesnono

可以参考tensorrtCV项目 github: https://github.com/wdhao/tensorrtCV

  • 不包含训练参数的层:缩放层, Relu层,Pooling层等。

  • 包含训练参数的层:卷积层,全连接层等,要先加载权值文件。

1、从头创建engine的9个基本步骤

step1:创建logger

step2:创建builder

step3:创建network

step4:向network中添加网络层

step5:设置并标记输出

step6:创建config并设置最大batchsize和最大工作空间

step7:创建engine

step8:序列化保存engine

step9:释放资源

2、头文件

//TensorRT头文件
#include <assert.h>
#include <fstream>
#include <sstream>
#include <iostream>
#include <cmath>
#include <sys/stat.h>
#include <cmath>
#include <time.h>
#include <cuda_runtime_api.h>
#include <unordered_map>
#include <algorithm>
#include <float.h>
#include <string.h>
#include <chrono>
#include <iterator>

#include "NvInfer.h"
#include "NvCaffeParser.h"
#include "common.h"

#include "BatchStream.h"
#include "LegacyCalibrator.h"

using namespace nvinfer1;
using namespace nvcaffeparser1;

static Logger gLogger;

3、创建流程

//step1:创建logger:日志记录器
class Logger : public ILogger           
 {
     void log(Severity severity, const char* msg) override
     {
         // suppress info-level messages
         if (severity != Severity::kINFO)
             std::cout << msg << std::endl;
     }
 } gLogger;


//step2:创建builder,默认是FP32模式
IBuilder* builder = createInferBuilder(gLogger);

//step3:创建network,动态shape网络,用户不得在显式精度模式下使用校准器或提供动态范围。
INetworkDefinition* network = builder->createNetworkV2(1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

//step4:向network中添加网络层
//添加输入层data(一个网路可以有多个输入)
auto data = network->addInput(INPUT_BLOB_NAME, DataType::kFloat, Dims3{-1, 1, INPUT_H, INPUT_W});
assert(data);

//加载权值文件,加载一次即可,mnistapi.wts 文件存放着网络中各个层间的权值系数
std::map<std::string, Weights> weightMap = loadWeights(locateFile("mnistapi.wts"));

//添加卷基层conv1
auto conv1 = network->addConvolution(*data->getOutput(0), 20, DimsHW{5, 5}, weightMap["conv1filter"],weightMap["conv1bias"]);
conv1->setStride(DimsHW{1, 1});
assert(conv1);
//添加池化层
auto pool1 = network->addPooling(*conv1->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
pool1->setStride(DimsHW{2, 2});
assert(pool1);
//添加全连接层
auto ip1 = network->addFullyConnected(*pool1->getOutput(0), 500, weightMap["ip1filter"], weightMap["ip1bias"]);
assert(ip1);
//添加激活层
auto relu1 = network->addActivation(*ip1->getOutput(0), ActivationType::kRELU);
assert(relu1);
//添加softmax层,得到概率分布 prob
auto prob = network->addSoftMax(*relu1->getOutput(0));
assert(prob);

//step5:设置并标记输出,一定要标记输出,标记的输出tensor才不会被优化
prob->getOutput(0)->setName(OUTPUT_BLOB_NAME);
network->markOutput(*prob->getOutput(0));

//step6:创建config并设置最大batchsize和最大工作空间
IBuilderConfig* config = builder->createBuilderConfig();
config->setMaxBatchSize(maxBatchSize);//设置最大batchsize
config->setMaxWorkspaceSize(1 << 30);//2^30 ,这里是1G

//step7:创建engine
ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);//层间融合和张量融合在这里执行
assert(engine);

//step8:序列化保存engine到planfile
IHostMemory *serializedModel = engine->serialize();
assert(serializedModel != nullptr)
std::ofstream p("xxxxx.engine");
p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());


//step9:释放资源
serializedModel->destroy();
engine->destroy();
network->destroy();
config->destroy();
builder->destroy();
//使用底层C++API创建网络时需要释放权重文件占用内存
for (auto& mem : weightMap)
{
    free((void*) (mem.second.values));
}
//pluginFactory.destroyPlugin();//自定义插件销毁
//shutdownProtobufLibrary();// 关闭protobuf(原始缓冲区)库
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

米斯特龙_ZXL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值