使用TensorRT对caffe和pytorch onnx模型进行fp32和fp16推理

本文首发于个人博客https://kezunlin.me/post/bcdfb73c/,欢迎阅读最新内容!

tensorrt fp32 fp16 tutorial with caffe pytorch minist model

Series

Code Example

include headers

#include <assert.h>
#include <sys/stat.h>
#include <time.h>

#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <algorithm>

#include <cuda_runtime_api.h>

#include "NvCaffeParser.h"
#include "NvOnnxConfig.h"
#include "NvOnnxParser.h"
#include "NvInfer.h"
#include "common.h"

using namespace nvinfer1;
using namespace nvcaffeparser1;

static Logger gLogger;

// Attributes of MNIST Caffe model
static const int INPUT_H = 28;
static const int INPUT_W = 28;
static const int OUTPUT_SIZE = 10;
//const char* INPUT_BLOB_NAME = "data";
const char* OUTPUT_BLOB_NAME = "prob";
const std::string mnist_data_dir = "data/mnist/";


// Simple PGM (portable greyscale map) reader
void readPGMFile(const std::string& fileName, uint8_t buffer[INPUT_H * INPUT_W])
{
    readPGMFile(fileName, buffer, INPUT_H, INPUT_W);
}

caffe model to tensorrt

void caffeToTRTModel(const std::string& deployFilepath,       // Path of Caffe prototxt file
                     const std::string& modelFilepath,        // Path of Caffe model file
                     const std::vector<std::string>& outputs, // Names of network outputs
                     unsigned int maxBatchSize,               // Note: Must be at least as large as the batch we want to run with
                     IHostMemory*& trtModelStream)            // Output buffer for the TRT model
{
    // Create builder
    IBuilder* builder = createInferBuilder(gLogger);

    // Parse caffe model to populate network, then set the outputs
    std::cout << "Reading Caffe prototxt: " << deployFilepath << "\n";
    std::cout << "Reading Caffe model: " << modelFilepath << "\n";
    INetworkDefinition* network = builder->createNetwork();
    ICaffeParser* parser = createCaffeParser();

    bool useFp16 = builder->platformHasFastFp16();
    std::cout << "platformHasFastFp16: " << useFp16 << "\n";

    bool useInt8 = builder->platformHasFastInt8();
    std::cout << "platformHasFastInt8: " <<
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值