TensorRT MNIST (TensorFlow UFF) sample
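This is the TensorRT sample_uff_mnist sample: it imports a TensorFlow LeNet-5 MNIST model serialized as UFF (lenet5.uff), builds a TensorRT inference engine (optionally with FP16, INT8, or DLA enabled), classifies the ten digit images 0.pgm through 9.pgm, and reports the average inference time.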

#include "common/argsParser.h"
#include "common/buffers.h"
#include "common/common.h"
#include "common/logger.h"

#include "NvUffParser.h"
#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <sys/stat.h>
#include <unordered_map>
#include <vector>

const std::string gSampleName = "TensorRT.sample_uff_mnist";
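// SampleUffMNIST builds a TensorRT engine from a UFF-serialized MNIST model
// and uses it to classify the ten sample digit images.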
class SampleUffMNIST
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    SampleUffMNIST(const samplesCommon::UffSampleParams& params)
        : mParams(params)
    {
    }
    bool build();
    bool infer();
    bool teardown();
private:
    bool constructNetwork(SampleUniquePtr<nvuffparser::IUffParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network);
    bool processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const;
    bool verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const;

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine{nullptr};
    samplesCommon::UffSampleParams mParams;

    nvinfer1::Dims mInputDims;
    const int kDIGITS{10};
};
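// Creates the builder, network, config, and UFF parser, then builds the engine.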
bool SampleUffMNIST::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
    {
        return false;
    }
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }
    auto parser = SampleUniquePtr<nvuffparser::IUffParser>(nvuffparser::createUffParser());
    if (!parser)
    {
        return false;
    }
    if (!constructNetwork(parser, network))
    {
        return false;
    }
    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    // Allow layers that cannot run on DLA to fall back to the GPU.
    config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
    //config->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
    if (mParams.fp16)
    {
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(nvinfer1::BuilderFlag::kINT8);
    }
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);

    return true;
}
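// Registers the input/output tensors with the UFF parser and populates the network from the UFF file.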
bool SampleUffMNIST::constructNetwork(SampleUniquePtr<nvuffparser::IUffParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network)
{
    assert(mParams.inputTensorNames.size() == 1);
    assert(mParams.outputTensorNames.size() == 1);
    // Register the model's input (CHW 1x28x28) and output tensors, then parse the UFF file into the network.
    parser->registerInput(mParams.inputTensorNames[0].c_str(), nvinfer1::Dims3(1, 28, 28), nvuffparser::UffInputOrder::kNCHW);
    parser->registerOutput(mParams.outputTensorNames[0].c_str());
    if (!parser->parse(mParams.uffFileName.c_str(), *network, nvinfer1::DataType::kFLOAT))
    {
        return false;
    }

    if (mParams.int8)
    {
        // Set dummy per-tensor scales so INT8 mode can run without a calibrator.
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }
    return true;
}
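// Runs inference once per digit (0-9), timing each run and checking the prediction.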
bool SampleUffMNIST::infer()
{
    // Create the RAII buffer manager and the execution context.
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }
    bool outputCorrect = true;
    float total = 0;
    // Try to infer each digit 0-9
    for (int digit = 0; digit < kDIGITS; digit++)
    {
        if (!processInput(buffers, mParams.inputTensorNames[0], digit))
        {
            return false;
        }
        // Copy data from host input buffers to device input buffers
        buffers.copyInputToDevice();

        const auto t_start = std::chrono::high_resolution_clock::now();

        // Execute the inference work
        if (!context->execute(mParams.batchSize, buffers.getDeviceBindings().data()))
        {
            return false;
        }

        const auto t_end = std::chrono::high_resolution_clock::now();
        const float ms = std::chrono::duration<float, std::milli>(t_end - t_start).count();
        total += ms;

        // Copy data from device output buffers to host output buffers
        buffers.copyOutputToHost();

        // Check and print the output of the inference
        outputCorrect &= verifyOutput(buffers, mParams.outputTensorNames[0], digit);
    }

    total /= kDIGITS;

    sample::gLogInfo << "Average over " << kDIGITS << " runs is " << total << " ms." << std::endl;

    return outputCorrect;
}
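// Reads <digit>.pgm, prints it as ASCII art, and writes the normalized pixels into the host input buffer.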
bool SampleUffMNIST::processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];

    // Read the PGM image for the requested digit.
    std::vector<uint8_t> fileData(inputH * inputW);
    readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Render the digit as ASCII art in the log.
    sample::gLogInfo << "Input:\n";
    for (int i = 0; i < inputH * inputW; i++)
    {
        sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % inputW) ? "" : "\n");
    }
    sample::gLogInfo << std::endl;

    // Normalize pixels to [0, 1] and invert them before copying into the host buffer.
    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostInputBuffer[i] = 1.0f - float(fileData[i]) / 255.0f;
    }
    return true;
}
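// Scans the output scores, prints them, and returns whether the argmax matches the ground-truth digit.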
bool SampleUffMNIST::verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
{
    const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));

    sample::gLogInfo << "Output:\n";

    // Find the index with the highest score: the predicted digit.
    float val{0.0f};
    int idx{0};
    for (int i = 0; i < kDIGITS; i++)
    {
        if (val < prob[i])
        {
            val = prob[i];
            idx = i;
        }
    }
    // Print output values for each index
    for (int j = 0; j < kDIGITS; j++)
    {
        sample::gLogInfo << j << "=> " << std::setw(10) << prob[j] << "\t : ";

        // Emphasize index with highest output value
        if (j == idx)
        {
            sample::gLogInfo << "***";
        }
        sample::gLogInfo << "\n";
    }

    sample::gLogInfo << std::endl;
    return (idx == groundTruthDigit);
}
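// Releases the protobuf resources held by the UFF parser.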
bool SampleUffMNIST::teardown()
{
    nvuffparser::shutdownProtobufLibrary();
    return true;
}
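// Fills the sample parameters (data directories, UFF file, tensor names, precision flags) from the CLI args.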
samplesCommon::UffSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::UffSampleParams params;
    if (args.dataDirs.empty())
    {
        params.dataDirs.push_back("./data/");
    }
    else
    {
        params.dataDirs = args.dataDirs;
    }

    params.uffFileName = locateFile("lenet5.uff", params.dataDirs);
    params.inputTensorNames.push_back("in");
    params.batchSize = 1;
    params.outputTensorNames.push_back("out");
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;
    return params;
}
void printHelpInfo()
{
    std::cout
        << "Usage: ./sample_uff_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]\n";
    std::cout << "--help          Display help information\n";
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "(./data/)"
              << std::endl;
    std::cout << "--useDLACore=N  Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                 "where n is the number of DLA engines on the platform."
              << std::endl;
    std::cout << "--int8          Run in Int8 mode.\n";
    std::cout << "--fp16          Run in FP16 mode.\n";
}
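// Parses the CLI arguments, then builds the engine, runs inference, and tears down, reporting pass/fail.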
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }
    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
    sample::gLogger.reportTestStart(sampleTest);

    samplesCommon::UffSampleParams params = initializeSampleParams(args);
    SampleUffMNIST sample(params);
    sample::gLogInfo << "Building and running a GPU inference engine for MNIST" << std::endl;
    if (!sample.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    if (!sample.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    if (!sample.teardown())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    return sample::gLogger.reportPass(sampleTest);
}
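To try it, place lenet5.uff and the digit images 0.pgm through 9.pgm in ./data/ (or pass --datadir), then run, for example: ./sample_uff_mnist --datadir=./data/ --fp16. The binary name is an assumption based on gSampleName; adjust it to match how you build this file.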
