TensorRT - build and run GoogleNet in TensorRT

1 description

  1. build engine
bool SampleGoogleNet::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if(!builder)
    {
        return false;
    }
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if(!network)
    {
        return false;
    }
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if(!config)
    {
        return false;
    }
    auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
    if(!parser)
    {
        return false;
    }
    constructNetwork(parser, network);
    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    samplesCommon::enableDLA(builder.get(), config.get(),mParams.dlaCore);
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network,*config),samplesCommon::InferDeleter());
    if(!mEngine)
    {
        return false;
    }
    return true;
}
  1. infer
/// Runs one inference pass over zero-filled inputs.
///
/// Sets up a samplesCommon::BufferManager for the engine, creates an execution context,
/// zeroes the host buffer of every tensor named in mParams.inputTensorNames, copies the
/// inputs to the device, executes the engine and copies the outputs back to the host.
///
/// @return true on success; false when the engine has not been built, the execution
///         context cannot be created, an input tensor is missing, or execution fails.
bool SampleGoogleNet::infer()
{
    // Guard against infer() being called before a successful build(): the context
    // creation below dereferences the engine pointer.
    if (!mEngine)
    {
        sample::gLogError << "engine has not been built\n";
        return false;
    }

    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);

    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    for (const auto& input : mParams.inputTensorNames)
    {
        const auto bufferSize = buffers.size(input);
        if (bufferSize == samplesCommon::BufferManager::kINVALID_SIZE_VALUE)
        {
            sample::gLogError << "input tensor missing: " << input << "\n";
            // Must be `false`: this function returns bool, and EXIT_FAILURE (non-zero on
            // common platforms) would convert to `true`, i.e. report success.
            return false;
        }
        // The sample feeds all-zero input data; only the execution path is exercised.
        memset(buffers.getHostBuffer(input), 0, bufferSize);
    }
    buffers.copyInputToDevice();

    const bool status = context->execute(mParams.batchSize, buffers.getDeviceBindings().data());
    if (!status)
    {
        return false;
    }
    buffers.copyOutputToHost();

    return true;
}
  1. teardown
/// Final cleanup step of the sample.
///
/// Calls nvcaffeparser1::shutdownProtobufLibrary() and unconditionally reports success.
///
/// @return always true.
bool SampleGoogleNet::teardown()
{
    nvcaffeparser1::shutdownProtobufLibrary();

    const bool succeeded = true;
    return succeeded;
}

2 code

#include "common/argsParser.h"
#include "common/buffers.h"
#include "common/common.h"
#include "common/logger.h"
#include "NvCaffeParser.h"
#include "NvInfer.h"
#include <cuda_runtime_api.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>

// Name passed to sample::gLogger.defineTest() in main() to identify this sample's test run.
const std::string gSampleName = "TensorRT.sample_googlenet";
/// GoogleNet sample driver: builds a TensorRT engine from a Caffe model (build),
/// runs it once (infer) and shuts the parser library down (teardown).
class SampleGoogleNet
{
private:
    /// Owning pointer for TensorRT objects; they are released through samplesCommon::InferDeleter.
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    /// Stores a copy of the sample parameters; no TensorRT work happens here.
    SampleGoogleNet(const samplesCommon::CaffeSampleParams& params) : mParams(params) {}

    /// Creates the network and builds the engine. Returns false on failure.
    bool build();

    /// Executes the built engine once. Returns false on failure.
    bool infer();

    /// Cleans up library state after the run.
    bool teardown();

    /// Parameters of the sample; public because main() reads the tensor names from it.
    samplesCommon::CaffeSampleParams mParams;

private:
    /// Parses the Caffe model into `network` and marks the configured outputs.
    void constructNetwork(
        SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network);

    /// Engine produced by build(); empty until build() succeeds.
    std::shared_ptr<nvinfer1::ICudaEngine> mEngine{};
};
bool SampleGoogleNet::build()
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if(!builder)
    {
        return false;
    }
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if(!network)
    {
        return false;
    }
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if(!config)
    {
        return false;
    }
    auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
    if(!parser)
    {
        return false;
    }
    constructNetwork(parser, network);
    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    samplesCommon::enableDLA(builder.get(), config.get(),mParams.dlaCore);
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network,*config),samplesCommon::InferDeleter());
    if(!mEngine)
    {
        return false;
    }
    return true;
}
/// Parses the Caffe model into `network` and marks the configured output tensors.
///
/// @param parser  Caffe parser used to read mParams.prototxtFileName and
///                mParams.weightsFileName (weights parsed as kFLOAT).
/// @param network Network definition that receives the parsed layers and output markings.
///
/// Failures are logged instead of dereferencing null pointers: when parsing yields no
/// blob-name table nothing is marked, and an output name that cannot be resolved is
/// skipped. NOTE(review): a network without marked outputs is presumably rejected by the
/// subsequent engine build, which makes build() return false - confirm against the
/// TensorRT version in use.
void SampleGoogleNet::constructNetwork(
    SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, SampleUniquePtr<nvinfer1::INetworkDefinition>& network)
{
    const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
        mParams.prototxtFileName.c_str(), mParams.weightsFileName.c_str(), *network, nvinfer1::DataType::kFLOAT);
    if (blobNameToTensor == nullptr)
    {
        sample::gLogError << "failed to parse Caffe model: " << mParams.prototxtFileName << "\n";
        return;
    }

    for (const auto& outputName : mParams.outputTensorNames)
    {
        auto* tensor = blobNameToTensor->find(outputName.c_str());
        if (tensor == nullptr)
        {
            sample::gLogError << "output tensor missing: " << outputName << "\n";
            continue;
        }
        network->markOutput(*tensor);
    }
}
/// Runs one inference pass over zero-filled inputs.
///
/// Sets up a samplesCommon::BufferManager for the engine, creates an execution context,
/// zeroes the host buffer of every tensor named in mParams.inputTensorNames, copies the
/// inputs to the device, executes the engine and copies the outputs back to the host.
///
/// @return true on success; false when the engine has not been built, the execution
///         context cannot be created, an input tensor is missing, or execution fails.
bool SampleGoogleNet::infer()
{
    // Guard against infer() being called before a successful build(): the context
    // creation below dereferences the engine pointer.
    if (!mEngine)
    {
        sample::gLogError << "engine has not been built\n";
        return false;
    }

    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);

    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    for (const auto& input : mParams.inputTensorNames)
    {
        const auto bufferSize = buffers.size(input);
        if (bufferSize == samplesCommon::BufferManager::kINVALID_SIZE_VALUE)
        {
            sample::gLogError << "input tensor missing: " << input << "\n";
            // Must be `false`: this function returns bool, and EXIT_FAILURE (non-zero on
            // common platforms) would convert to `true`, i.e. report success.
            return false;
        }
        // The sample feeds all-zero input data; only the execution path is exercised.
        memset(buffers.getHostBuffer(input), 0, bufferSize);
    }
    buffers.copyInputToDevice();

    const bool status = context->execute(mParams.batchSize, buffers.getDeviceBindings().data());
    if (!status)
    {
        return false;
    }
    buffers.copyOutputToHost();

    return true;
}
/// Final cleanup step of the sample.
///
/// Calls nvcaffeparser1::shutdownProtobufLibrary() and unconditionally reports success.
///
/// @return always true.
bool SampleGoogleNet::teardown()
{
    nvcaffeparser1::shutdownProtobufLibrary();

    const bool succeeded = true;
    return succeeded;
}
/// Translates the parsed command-line arguments into the parameters of the sample.
///
/// @param args Parsed command-line arguments; `dataDirs` and `useDLACore` are read.
/// @return Parameters with batch size 4, input tensor "data", output tensor "prob", the
///         DLA core from `args`, and the prototxt / caffemodel paths resolved via
///         locateFile() against the data directories ("data/" when none were given).
samplesCommon::CaffeSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::CaffeSampleParams sampleParams;

    // Settings that do not depend on the data directories.
    sampleParams.batchSize = 4;
    sampleParams.dlaCore = args.useDLACore;
    sampleParams.inputTensorNames.emplace_back("data");
    sampleParams.outputTensorNames.emplace_back("prob");

    // User-supplied directories win; otherwise fall back to the default one.
    if (!args.dataDirs.empty())
    {
        sampleParams.dataDirs = args.dataDirs;
    }
    else
    {
        sampleParams.dataDirs.emplace_back("data/");
    }

    // Model files are looked up only after the search directories are final.
    sampleParams.prototxtFileName = locateFile("googlenet.prototxt", sampleParams.dataDirs);
    sampleParams.weightsFileName = locateFile("googlenet.caffemodel", sampleParams.dataDirs);

    return sampleParams;
}
/// Prints the command-line usage of the sample to std::cout.
///
/// The default data directory named here is "data/", matching the fallback that
/// initializeSampleParams() applies when no --datadir option is given (the previous
/// text advertised data/samples/googlenet/ and data/googlenet/, which are not used).
void printHelpInfo()
{
    std::cout
        << "Usage: ./sample_googlenet [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]\n";
    std::cout << "--help          Display help information\n";
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "data/"
              << std::endl;
    std::cout << "--useDLACore=N  Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
                 "where n is the number of DLA engines on the platform."
              << std::endl;
}
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }

    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);

    sample::gLogger.reportTestStart(sampleTest);

    samplesCommon::CaffeSampleParams params = initializeSampleParams(args);
    SampleGoogleNet sample(params);

    sample::gLogInfo << "Building and running a GPU inference engine for GoogleNet" << std::endl;

    if (!sample.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    if (!sample.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    if (!sample.teardown())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    sample::gLogInfo << "Ran " << argv[0] << " with: " << std::endl;

    std::stringstream ss;

    ss << "Input(s): ";
    for (auto& input : sample.mParams.inputTensorNames)
    {
        ss << input << " ";
    }

    sample::gLogInfo << ss.str() << std::endl;

    ss.str(std::string());

    ss << "Output(s): ";
    for (auto& output : sample.mParams.outputTensorNames)
    {
        ss << output << " ";
    }

    sample::gLogInfo << ss.str() << std::endl;

    return sample::gLogger.reportPass(sampleTest);
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值