本文首发于个人博客https://kezunlin.me/post/bcdfb73c/,欢迎阅读最新内容!
TensorRT FP32/FP16 tutorial with Caffe and PyTorch MNIST models
Series
- Part 1: install and configure tensorrt 4 on ubuntu 16.04
- Part 2: tensorrt fp32 fp16 tutorial
- Part 3: tensorrt int8 tutorial
Code Example
include headers
#include <assert.h>
#include <sys/stat.h>
#include <time.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <algorithm>
#include <cuda_runtime_api.h>
#include "NvCaffeParser.h"
#include "NvOnnxConfig.h"
#include "NvOnnxParser.h"
#include "NvInfer.h"
#include "common.h"
// Bring the TensorRT core and Caffe-parser namespaces into scope so the
// builder/parser types below can be named without qualification.
using namespace nvinfer1;
using namespace nvcaffeparser1;
// File-local logger instance; it is handed to createInferBuilder() when the
// engine is built. (Logger is presumably declared in common.h -- confirm.)
static Logger gLogger;
// Attributes of MNIST Caffe model
// Input image height and width; also used to size the PGM pixel buffer.
static const int INPUT_H = 28;
static const int INPUT_W = 28;
// Number of values in the network output (presumably one score per digit
// class 0-9 -- TODO confirm against the model definition).
static const int OUTPUT_SIZE = 10;
// The input blob name is currently disabled; kept for reference.
//const char* INPUT_BLOB_NAME = "data";
// Name of the network output blob.
const char* OUTPUT_BLOB_NAME = "prob";
// Path prefix under which the MNIST data files are looked up.
const std::string mnist_data_dir = "data/mnist/";
// Convenience PGM (portable greyscale map) reader for MNIST-sized images.
// Forwards to the four-argument readPGMFile overload, supplying the fixed
// image dimensions INPUT_H x INPUT_W so callers only pass the file name and
// a destination buffer of INPUT_H * INPUT_W bytes.
void readPGMFile(const std::string& fileName, uint8_t buffer[INPUT_H * INPUT_W])
{
    const int imageHeight = INPUT_H;
    const int imageWidth = INPUT_W;
    readPGMFile(fileName, buffer, imageHeight, imageWidth);
}
caffe model to tensorrt
void caffeToTRTModel(const std::string& deployFilepath, // Path of Caffe prototxt file
const std::string& modelFilepath, // Path of Caffe model file
const std::vector<std::string>& outputs, // Names of network outputs
unsigned int maxBatchSize, // Note: Must be at least as large as the batch we want to run with
IHostMemory*& trtModelStream) // Output buffer for the TRT model
{
// Create builder
IBuilder* builder = createInferBuilder(gLogger);
// Parse caffe model to populate network, then set the outputs
std::cout << "Reading Caffe prototxt: " << deployFilepath << "\n";
std::cout << "Reading Caffe model: " << modelFilepath << "\n";
INetworkDefinition* network = builder->createNetwork();
ICaffeParser* parser = createCaffeParser();
bool useFp16 = builder->platformHasFastFp16();
std::cout << "platformHasFastFp16: " << useFp16 << "\n";
bool useInt8 = builder->platformHasFastInt8();
std::cout << "platformHasFastInt8: " <<