









#pragma once
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#include "NvInferRuntimeCommon.h"
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sstream>
#include <string>

using Severity = nvinfer1::ILogger::Severity;

//! \class LogStreamConsumerBuffer
//! \brief std::stringbuf that accumulates one log message and, on sync() or destruction, writes
//!        "[MM/DD/YYYY-HH:MM:SS] <prefix><message>" to the wrapped output stream.
class LogStreamConsumerBuffer : public std::stringbuf
{
public:
    //! \param stream    Destination stream; held by reference, so it must outlive this buffer.
    //! \param prefix    Text written after the timestamp, e.g. "[E] ".
    //! \param shouldLog When false, buffered text is discarded instead of written.
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream)
        , mPrefix(prefix)
        , mShouldLog(shouldLog)
    {
    }

    //! Move constructor: copies destination, prefix and log flag. Text still buffered in `other`
    //! is not transferred. (Previously only mOutput was copied, leaving mPrefix empty and
    //! mShouldLog uninitialized, so reading it was undefined behavior.)
    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
        : mOutput(other.mOutput)
        , mPrefix(other.mPrefix)
        , mShouldLog(other.mShouldLog)
    {
    }

    ~LogStreamConsumerBuffer() override
    {
        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
        // if the pointer to the beginning is not equal to the pointer to the current position,
        // call putOutput() to log the output to the stream
        if (pbase() != pptr())
        {
            putOutput();
        }
    }

    // synchronizes the stream buffer and returns 0 on success
    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
    // resetting the buffer and flushing the stream
    int sync() override
    {
        putOutput();
        return 0;
    }

    //! Writes timestamp + prefix + buffered text to the destination (when logging is enabled),
    //! flushes it, and always empties the buffer.
    void putOutput()
    {
        if (mShouldLog)
        {
            // Timestamp goes to the same stream as the message (it used to go to std::cout even
            // for messages routed to std::cerr).
            mOutput << timestampPrefix() << mPrefix << str();
            mOutput.flush();
        }
        // Reset even when logging is disabled so suppressed messages do not accumulate.
        str("");
    }

    void setShouldLog(bool shouldLog)
    {
        mShouldLog = shouldLog;
    }

private:
    //! Formats the current local time as "[MM/DD/YYYY-HH:MM:SS] ".
    //! Uses a stack std::tm (the previous `new tm()` leaked on every message) and a local stream so
    //! setfill/setw state does not leak into the destination stream.
    static std::string timestampPrefix()
    {
        const std::time_t timestamp = std::time(nullptr);
        std::tm tmLocal{};
#if defined(_WIN32)
        localtime_s(&tmLocal, &timestamp); // MSVC: (tm*, const time_t*)
#else
        localtime_r(&timestamp, &tmLocal); // POSIX: (const time_t*, tm*)
#endif
        std::ostringstream oss;
        oss << std::setfill('0') << '[' << std::setw(2) << 1 + tmLocal.tm_mon << '/' << std::setw(2)
            << tmLocal.tm_mday << '/' << std::setw(4) << 1900 + tmLocal.tm_year << '-' << std::setw(2)
            << tmLocal.tm_hour << ':' << std::setw(2) << tmLocal.tm_min << ':' << std::setw(2)
            << tmLocal.tm_sec << "] ";
        return oss.str();
    }

    std::ostream& mOutput;
    std::string mPrefix;
    bool mShouldLog;
};

//! \class LogStreamConsumerBase
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
class LogStreamConsumerBase
    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mBuffer(stream, prefix, shouldLog)

    LogStreamConsumerBuffer mBuffer;

//! \class LogStreamConsumer
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
//!  Please do not change the order of the parent classes.
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
    //!  Reportable severity determines if the messages are severe enough to be logged.
    LogStreamConsumer(Severity reportableSeverity, Severity severity)
        : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(severity <= reportableSeverity)
        , mSeverity(severity)

    LogStreamConsumer(LogStreamConsumer&& other)
        : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(other.mShouldLog)
        , mSeverity(other.mSeverity)

    void setReportableSeverity(Severity reportableSeverity)
        mShouldLog = mSeverity <= reportableSeverity;

    static std::ostream& severityOstream(Severity severity)
        return severity >= Severity::kINFO ? std::cout : std::cerr;

    static std::string severityPrefix(Severity severity)
        switch (severity)
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";

    bool mShouldLog;
    Severity mSeverity;

//! \class Logger
//! \brief Class which manages logging of TensorRT tools and samples
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
//! and supports logging two types of messages:
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
//! - Test pass/fail messages
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
//! In the future, this class could be extended to support dumping test results to a file in some standard format
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
//! library and messages coming from the sample.
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
//! object.

class Logger : public nvinfer1::ILogger
    Logger(Severity severity = Severity::kWARNING)
        : mReportableSeverity(severity)

    //! \enum TestResult
    //! \brief Represents the state of a given test
    enum class TestResult
        kRUNNING, //!< The test is running
        kPASSED,  //!< The test passed
        kFAILED,  //!< The test failed
        kWAIVED   //!< The test was waived

    //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
    //! \return The nvinfer1::ILogger associated with this Logger
    //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
    //! we can eliminate the inheritance of Logger from ILogger
    nvinfer1::ILogger& getTRTLogger()
        return *this;

    //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
    //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
    //! inheritance from nvinfer1::ILogger
    void log(Severity severity, const char* msg) noexcept override
        LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;

    //! \brief Method for controlling the verbosity of logging output
    //! \param severity The logger will only emit messages that have severity of this level or higher.
    void setReportableSeverity(Severity severity)
        mReportableSeverity = severity;

    //! \brief Opaque handle that holds logging information for a particular test
    //! This object is an opaque handle to information used by the Logger to print test results.
    //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
    //! with Logger::reportTest{Start,End}().
    class TestAtom
        TestAtom(TestAtom&&) = default;

        friend class Logger;

        TestAtom(bool started, const std::string& name, const std::string& cmdline)
            : mStarted(started)
            , mName(name)
            , mCmdline(cmdline)

        bool mStarted;
        std::string mName;
        std::string mCmdline;

    //! \brief Define a test for logging
    //! \param[in] name The name of the test.  This should be a string starting with
    //!                  "TensorRT" and containing dot-separated strings containing
    //!                  the characters [A-Za-z0-9_].
    //!                  For example, "TensorRT.sample_googlenet"
    //! \param[in] cmdline The command line used to reproduce the test
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, const std::string& cmdline)
        return TestAtom(false, name, cmdline);

    //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
    //!        as input
    //! \param[in] name The name of the test
    //! \param[in] argc The number of command-line arguments
    //! \param[in] argv The array of command-line arguments (given as C strings)
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
        auto cmdline = genCmdlineString(argc, argv);
        return defineTest(name, cmdline);

    //! \brief Report that a test has started.
    //! \pre reportTestStart() has not been called yet for the given testAtom
    //! \param[in] testAtom The handle to the test that has started
    static void reportTestStart(TestAtom& testAtom)
        reportTestResult(testAtom, TestResult::kRUNNING);
        testAtom.mStarted = true;

    //! \brief Report that a test has ended.
    //! \pre reportTestStart() has been called for the given testAtom
    //! \param[in] testAtom The handle to the test that has ended
    //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
    //!                   TestResult::kFAILED, TestResult::kWAIVED
    static void reportTestEnd(const TestAtom& testAtom, TestResult result)
        assert(result != TestResult::kRUNNING);
        reportTestResult(testAtom, result);

    static int reportPass(const TestAtom& testAtom)
        reportTestEnd(testAtom, TestResult::kPASSED);
        return EXIT_SUCCESS;

    static int reportFail(const TestAtom& testAtom)
        reportTestEnd(testAtom, TestResult::kFAILED);
        return EXIT_FAILURE;

    static int reportWaive(const TestAtom& testAtom)
        reportTestEnd(testAtom, TestResult::kWAIVED);
        return EXIT_SUCCESS;

    static int reportTest(const TestAtom& testAtom, bool pass)
        return pass ? reportPass(testAtom) : reportFail(testAtom);

    Severity getReportableSeverity() const
        return mReportableSeverity;

    //! \brief returns an appropriate string for prefixing a log message with the given severity
    static const char* severityPrefix(Severity severity)
        switch (severity)
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";

    //! \brief returns an appropriate string for prefixing a test result message with the given result
    static const char* testResultString(TestResult result)
        switch (result)
        case TestResult::kRUNNING: return "RUNNING";
        case TestResult::kPASSED: return "PASSED";
        case TestResult::kFAILED: return "FAILED";
        case TestResult::kWAIVED: return "WAIVED";
        default: assert(0); return "";

    //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
    static std::ostream& severityOstream(Severity severity)
        return severity >= Severity::kINFO ? std::cout : std::cerr;

    //! \brief method that implements logging test results
    static void reportTestResult(const TestAtom& testAtom, TestResult result)
        severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
            << testAtom.mCmdline << std::endl;

    //! \brief generate a command line string from the given (argc, argv) values
    static std::string genCmdlineString(int argc, char const* const* argv)
        std::stringstream ss;
        for (int i = 0; i < argc; i++)
            if (i > 0)
                ss << " ";
            ss << argv[i];
        return ss.str();

    Severity mReportableSeverity;


    // Stream-style logging helpers, one per severity; each returns a temporary LogStreamConsumer.
namespace
{
    //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
    //! Example usage:
    //!     LOG_VERBOSE(logger) << "hello world" << std::endl;
    inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
    {
        return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
    }

    //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
    //! Example usage:
    //!     LOG_INFO(logger) << "hello world" << std::endl;
    inline LogStreamConsumer LOG_INFO(const Logger& logger)
    {
        return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
    }

    //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
    //! Example usage:
    //!     LOG_WARN(logger) << "hello world" << std::endl;
    inline LogStreamConsumer LOG_WARN(const Logger& logger)
    {
        return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
    }

    //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
    //! Example usage:
    //!     LOG_ERROR(logger) << "hello world" << std::endl;
    inline LogStreamConsumer LOG_ERROR(const Logger& logger)
    {
        return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
    }

    //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
    //         ("fatal" severity)
    //! Example usage:
    //!     LOG_FATAL(logger) << "hello world" << std::endl;
    inline LogStreamConsumer LOG_FATAL(const Logger& logger)
    {
        return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
    }
} // anonymous namespace



#pragma once
#include <algorithm> 
#include <fstream>
#include <iostream>
#include <opencv2/opencv.hpp>
#include <vector>
#include <chrono>
#include <cmath>
#include <numeric> // std::iota 

using  namespace cv;

//! Evaluates a CUDA runtime call exactly once; on a non-zero status (anything but cudaSuccess)
//! prints the code to std::cerr and aborts, since continuing with failed allocations/copies
//! leads to undefined behavior. The do/while(0) wrapper makes the macro a single statement,
//! so it is safe inside an un-braced if/else.
#define CHECK(status) \
    do \
    { \
        auto ret = (status); \
        if (ret != 0) \
        { \
            std::cerr << "Cuda failure: " << ret << std::endl; \
            std::abort(); \
        } \
    } while (0)
//! One detection: box, combined score and class index.
struct alignas(float) Detection {
	// center_x center_y w h
	float bbox[4];
	float conf;  // bbox_conf * cls_conf
	int class_id;
};
//! Letterbox-resizes `img` to input_w x input_h: scales it to fit while keeping the aspect ratio,
//! centers it and fills the border with gray (128, 128, 128).
//! \param[out] padsize Cleared, then set to {resized_h, resized_w, pad_top, pad_left}, i.e.
//!                     newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3].
//! \return A new CV_8UC3 image of input_h rows by input_w columns.
static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h, std::vector<int>& padsize) {
	int w, h, x, y;
	float r_w = input_w / (img.cols * 1.0);
	float r_h = input_h / (img.rows * 1.0);
	if (r_h > r_w) { // image is relatively wider than the input: fit the width, pad top/bottom
		w = input_w;
		h = r_w * img.rows;
		x = 0;
		y = (input_h - h) / 2;
	}
	else { // fit the height, pad left/right
		w = r_h * img.cols;
		h = input_h;
		x = (input_w - w) / 2;
		y = 0;
	}
	cv::Mat re(h, w, CV_8UC3);
	cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
	cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
	re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
	// Reset first so a reused vector cannot leave stale values in padsize[0..3].
	padsize.clear();
	padsize.push_back(h);
	padsize.push_back(w);
	padsize.push_back(y);
	padsize.push_back(x);
	return out;
}
//! Maps a box predicted on the letterboxed network input back to `img` pixel coordinates.
//! \param bbox Box as {center_x, center_y, w, h} in network-input pixels (the Detection::bbox
//!             layout). The first branch previously read it as {x1, y1, x2, y2}, which disagreed
//!             with the second branch.
//! \return The box in `img` pixels (not clipped to the image).
//! `inline` because this is defined in a header (avoids multiple-definition link errors).
inline cv::Rect get_rect(cv::Mat& img, float bbox[4], int INPUT_W, int INPUT_H) {
	float l, r, t, b;
	const float r_w = INPUT_W / (img.cols * 1.0);
	const float r_h = INPUT_H / (img.rows * 1.0);
	if (r_h > r_w) {
		// Width was fitted: padding is above and below.
		const float pad_y = (INPUT_H - r_w * img.rows) / 2;
		l = bbox[0] - bbox[2] / 2.f;
		r = bbox[0] + bbox[2] / 2.f;
		t = bbox[1] - bbox[3] / 2.f - pad_y;
		b = bbox[1] + bbox[3] / 2.f - pad_y;
		l /= r_w;
		r /= r_w;
		t /= r_w;
		b /= r_w;
	}
	else {
		// Height was fitted: padding is left and right.
		const float pad_x = (INPUT_W - r_h * img.cols) / 2;
		l = bbox[0] - bbox[2] / 2.f - pad_x;
		r = bbox[0] + bbox[2] / 2.f - pad_x;
		t = bbox[1] - bbox[3] / 2.f;
		b = bbox[1] + bbox[3] / 2.f;
		l /= r_h;
		r /= r_h;
		t /= r_h;
		b /= r_h;
	}
	// Truncate only once, at the end, instead of after every intermediate step.
	const int x0 = static_cast<int>(l);
	const int y0 = static_cast<int>(t);
	return cv::Rect(x0, y0, static_cast<int>(r) - x0, static_cast<int>(b) - y0);
}


#pragma once
#include "NvInfer.h"
#include "cuda_runtime_api.h"
#include "NvInferPlugin.h"
#include "logging.h"
#include <opencv2/opencv.hpp>
#include "utils.h"
#include <string>
using namespace nvinfer1;
using namespace cv;

// stuff we know about the network and the input/output blobs
// (these must match the exported model: binding names, input size, class count, mask prototypes)
static const int batchSize = 1;
static const int INPUT_H = 640;
static const int INPUT_W = 640;
static const int _segWidth = 160;
static const int _segHeight = 160;
static const int _segChannels = 32;
static const int CLASSES = 80;
static const int Num_box = 25200;
static const int OUTPUT_SIZE = batchSize * Num_box * (CLASSES + 5 + _segChannels);//output0
static const int OUTPUT_SIZE1 = batchSize * _segChannels * _segWidth * _segHeight;//output1
static const int INPUT_SIZE = batchSize * 3 * INPUT_H * INPUT_W;//images

static const float CONF_THRESHOLD = 0.1;
static const float NMS_THRESHOLD = 0.5;
static const float MASK_THRESHOLD = 0.5;
// The pointers themselves are const: the former `const char* NAME = ...` were non-const
// external-linkage definitions, so including this header in two translation units failed to link.
static const char* const INPUT_BLOB_NAME = "images";
static const char* const OUTPUT_BLOB_NAME = "output0";//detect
static const char* const OUTPUT_BLOB_NAME1 = "output1";//mask

// Host-side output buffers written by doInference and read by the decode steps; shared by every
// detect_img call, so concurrent detections would race on them.
static float prob[OUTPUT_SIZE];       //box
static float prob1[OUTPUT_SIZE1];      //mask

// TensorRT logger passed to createInferRuntime() by YOLO::init. Declared `static`, so each
// translation unit that includes this header gets its own instance.
static Logger gLogger;
//! One final (post-NMS) instance-segmentation result.
struct OutputSeg {
	int id = 0;               // class id
	float confidence = 0.f;   // score (objectness * best class score)
	cv::Rect box;             // bounding box in original-image pixels, clipped to the image
	cv::Mat boxMask;          // binary mask covering only `box` (smaller and faster than a full-image mask)
};

//! Candidates collected by YOLO::decode_boxs before NMS. The four vectors are parallel:
//! index i of each refers to the same candidate.
struct OutputObject
{
	std::vector<int> classIds;                         // class id per candidate
	std::vector<float> confidences;                    // score per candidate
	std::vector<cv::Rect> boxes;                       // box per candidate (original-image pixels)
	std::vector<std::vector<float>> picked_proposals;  // output0[:, :, 5 + CLASSES : net_width] per candidate, used later to compute the mask
};

// 80 colors with components in [0, 1]; drawing code scales them by 255 and uses them in
// (b, g, r) order.
const float color_list[80][3] =
{
	{0.000, 0.447, 0.741},
	{0.850, 0.325, 0.098},
	{0.929, 0.694, 0.125},
	{0.494, 0.184, 0.556},
	{0.466, 0.674, 0.188},
	{0.301, 0.745, 0.933},
	{0.635, 0.078, 0.184},
	{0.300, 0.300, 0.300},
	{0.600, 0.600, 0.600},
	{1.000, 0.000, 0.000},
	{1.000, 0.500, 0.000},
	{0.749, 0.749, 0.000},
	{0.000, 1.000, 0.000},
	{0.000, 0.000, 1.000},
	{0.667, 0.000, 1.000},
	{0.333, 0.333, 0.000},
	{0.333, 0.667, 0.000},
	{0.333, 1.000, 0.000},
	{0.667, 0.333, 0.000},
	{0.667, 0.667, 0.000},
	{0.667, 1.000, 0.000},
	{1.000, 0.333, 0.000},
	{1.000, 0.667, 0.000},
	{1.000, 1.000, 0.000},
	{0.000, 0.333, 0.500},
	{0.000, 0.667, 0.500},
	{0.000, 1.000, 0.500},
	{0.333, 0.000, 0.500},
	{0.333, 0.333, 0.500},
	{0.333, 0.667, 0.500},
	{0.333, 1.000, 0.500},
	{0.667, 0.000, 0.500},
	{0.667, 0.333, 0.500},
	{0.667, 0.667, 0.500},
	{0.667, 1.000, 0.500},
	{1.000, 0.000, 0.500},
	{1.000, 0.333, 0.500},
	{1.000, 0.667, 0.500},
	{1.000, 1.000, 0.500},
	{0.000, 0.333, 1.000},
	{0.000, 0.667, 1.000},
	{0.000, 1.000, 1.000},
	{0.333, 0.000, 1.000},
	{0.333, 0.333, 1.000},
	{0.333, 0.667, 1.000},
	{0.333, 1.000, 1.000},
	{0.667, 0.000, 1.000},
	{0.667, 0.333, 1.000},
	{0.667, 0.667, 1.000},
	{0.667, 1.000, 1.000},
	{1.000, 0.000, 1.000},
	{1.000, 0.333, 1.000},
	{1.000, 0.667, 1.000},
	{0.333, 0.000, 0.000},
	{0.500, 0.000, 0.000},
	{0.667, 0.000, 0.000},
	{0.833, 0.000, 0.000},
	{1.000, 0.000, 0.000},
	{0.000, 0.167, 0.000},
	{0.000, 0.333, 0.000},
	{0.000, 0.500, 0.000},
	{0.000, 0.667, 0.000},
	{0.000, 0.833, 0.000},
	{0.000, 1.000, 0.000},
	{0.000, 0.000, 0.167},
	{0.000, 0.000, 0.333},
	{0.000, 0.000, 0.500},
	{0.000, 0.000, 0.667},
	{0.000, 0.000, 0.833},
	{0.000, 0.000, 1.000},
	{0.000, 0.000, 0.000},
	{0.143, 0.143, 0.143},
	{0.286, 0.286, 0.286},
	{0.429, 0.429, 0.429},
	{0.571, 0.571, 0.571},
	{0.714, 0.714, 0.714},
	{0.857, 0.857, 0.857},
	{0.000, 0.447, 0.741},
	{0.314, 0.717, 0.741},
	{0.50, 0.5, 0}
};

static void DrawPred(Mat& img, std::vector<OutputSeg> result) {
	std::vector<Scalar> color;
	for (int i = 0; i < CLASSES; i++) {
		int b = rand() % 256;
		int g = rand() % 256;
		int r = rand() % 256;
		color.push_back(Scalar(b, g, r));
	Mat mask = img.clone();
	for (int i = 0; i < result.size(); i++) {
		int left, top;
		left = result[i].box.x;
		top = result[i].box.y;
		int color_num = i;
		rectangle(img, result[i].box, color[result[i].id], 2, 8);
		mask(result[i].box).setTo(color[result[i].id], result[i].boxMask);

		std::string label = std::to_string(result[i].id) + ":" + std::to_string(result[i].confidence);
		int baseLine;
		Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
		top = max(top, labelSize.height);
		putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, color[result[i].id], 2);
	addWeighted(img, 0.5, mask, 0.5, 0, img); //将mask加在原图上面


//! Runs one synchronous inference. It copies `input` (INPUT_SIZE floats) to the GPU, executes the
//! engine, and copies the bindings named OUTPUT_BLOB_NAME / OUTPUT_BLOB_NAME1 back into `output`
//! (OUTPUT_SIZE floats) and `output1` (OUTPUT_SIZE1 floats).
//! Device buffers and the CUDA stream are created and released within each call.
static void doInference(IExecutionContext& context, float* input, float* output, float* output1, int batchSize)
{
	const ICudaEngine& engine = context.getEngine();

	// The engine must have exactly three bindings: input, output0, output1.
	assert(engine.getNbBindings() == 3);
	void* buffers[3] = { nullptr, nullptr, nullptr };

	const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
	const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
	const int outputIndex1 = engine.getBindingIndex(OUTPUT_BLOB_NAME1);
	// getBindingIndex returns -1 for an unknown name; buffers[-1] would corrupt the stack.
	if (inputIndex < 0 || outputIndex < 0 || outputIndex1 < 0)
	{
		std::cerr << "doInference: engine bindings do not match INPUT_BLOB_NAME / OUTPUT_BLOB_NAME / OUTPUT_BLOB_NAME1" << std::endl;
		return;
	}

	// Allocate GPU memory for the model input and both outputs.
	CHECK(cudaMalloc(&buffers[inputIndex], INPUT_SIZE * sizeof(float)));
	CHECK(cudaMalloc(&buffers[outputIndex], OUTPUT_SIZE * sizeof(float)));
	CHECK(cudaMalloc(&buffers[outputIndex1], OUTPUT_SIZE1 * sizeof(float)));

	// The stream must be created before use (it was previously passed uninitialized).
	cudaStream_t stream;
	CHECK(cudaStreamCreate(&stream));

	// cudaMemcpyAsync is non-blocking: the copy is only queued on `stream`.
	CHECK(cudaMemcpyAsync(buffers[inputIndex], input, INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, stream));
	if (!context.enqueue(batchSize, buffers, stream, nullptr))
	{
		std::cerr << "doInference: enqueue failed" << std::endl;
	}
	CHECK(cudaMemcpyAsync(output, buffers[outputIndex], OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
	CHECK(cudaMemcpyAsync(output1, buffers[outputIndex1], OUTPUT_SIZE1 * sizeof(float), cudaMemcpyDeviceToHost, stream));
	// Wait for the queued work to finish before the host reads output/output1.
	CHECK(cudaStreamSynchronize(stream));

	// Release the stream and device buffers (they leaked on every call before).
	CHECK(cudaStreamDestroy(stream));
	CHECK(cudaFree(buffers[inputIndex]));
	CHECK(cudaFree(buffers[outputIndex]));
	CHECK(cudaFree(buffers[outputIndex1]));
}


class YOLO
	void init(std::string engine_path);
	void init(char* engine_path);
	void destroy();
	void blobFromImage(cv::Mat& img, float* data);
	void decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize);
	void nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output);
	void decode_mask(cv::Mat& src, float* prob1, std::vector<int> padsize, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output);
	void drawMask(Mat& img, std::vector<OutputSeg> result);
	void detect_img(std::string image_path);
	void detect_img(char* image_path, float(*res_array)[6], uchar(*mask_array));

	ICudaEngine* engine;
	IRuntime* runtime;
	IExecutionContext* context;		

void YOLO::destroy()

void YOLO::init(std::string engine_path)
	//{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
	size_t size{ 0 };
	//定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
	char* trtModelStream{ nullptr };
	std::ifstream file(engine_path, std::ios::binary);
	if (file.good())
		file.seekg(0, file.end);
		size = file.tellg();
		file.seekg(0, file.beg);
		trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
		file.read(trtModelStream, size);
	std::cout << "engine init finished" << std::endl;

	runtime = createInferRuntime(gLogger);
	assert(runtime != nullptr);
	engine = runtime->deserializeCudaEngine(trtModelStream, size);
	assert(engine != nullptr);
	context = engine->createExecutionContext();
	assert(context != nullptr);
	delete[] trtModelStream;

void YOLO::init(char* engine_path)
	//{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
	size_t size{ 0 };
	//定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
	char* trtModelStream{ nullptr };
	std::ifstream file(engine_path, std::ios::binary);
	if (file.good())
		file.seekg(0, file.end);
		size = file.tellg();
		file.seekg(0, file.beg);
		trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
		file.read(trtModelStream, size);
	std::cout << "engine init finished" << std::endl;

	runtime = createInferRuntime(gLogger);
	assert(runtime != nullptr);
	engine = runtime->deserializeCudaEngine(trtModelStream, size);
	assert(engine != nullptr);
	context = engine->createExecutionContext();
	assert(context != nullptr);
	delete[] trtModelStream;

//! Converts an 8-bit 3-channel image of at least INPUT_H x INPUT_W pixels into a planar float
//! tensor laid out [1, 3, INPUT_H, INPUT_W], scaled to [0, 1], with channel order reversed
//! (OpenCV's BGR -> RGB planes).
//! \param data Destination with room for 3 * INPUT_H * INPUT_W floats.
inline void YOLO::blobFromImage(cv::Mat& src, float* data)
{
	// Reading past a smaller or differently typed image would overrun src.data.
	CV_Assert(src.type() == CV_8UC3 && src.rows >= INPUT_H && src.cols >= INPUT_W);
	const int planeSize = INPUT_H * INPUT_W;
	int i = 0; // pixel index within one channel plane
	for (int row = 0; row < INPUT_H; ++row)
	{
		// src.step is the byte stride of one row (width * 3 for a continuous CV_8UC3 image)
		const uchar* uc_pixel = src.data + row * src.step;
		for (int col = 0; col < INPUT_W; ++col)
		{
			data[i] = static_cast<float>(uc_pixel[2]) / 255.0;
			data[i + planeSize] = static_cast<float>(uc_pixel[1]) / 255.0;
			data[i + 2 * planeSize] = static_cast<float>(uc_pixel[0]) / 255.0;
			uc_pixel += 3; // next pixel in this row
			++i;           // was missing: every pixel overwrote index 0
		}
	}
}

//! Scans the Num_box candidate rows of output0 (`prob`) and keeps those whose objectness and best
//! class score both reach CONF_THRESHOLD. Their boxes are mapped from letterboxed network
//! coordinates back to `src` pixels and appended to `outputObject`.
//! \param padsize {newh, neww, padh, padw} as produced by preprocess_img.
inline void YOLO::decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize)
{
	const int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
	const float ratio_h = (float)src.rows / newh;
	const float ratio_w = (float)src.cols / neww;

	// Row layout: [cx, cy, w, h, objectness, CLASSES class scores, _segChannels mask coefficients]
	const int net_width = CLASSES + 5 + _segChannels;
	float* pdata = prob;
	for (int j = 0; j < Num_box; ++j, pdata += net_width)
	{
		const float box_score = pdata[4];
		if (box_score < CONF_THRESHOLD)
			continue;

		cv::Mat scores(1, CLASSES, CV_32FC1, pdata + 5);
		Point classIdPoint;
		double maxVal;
		minMaxLoc(scores, 0, &maxVal, 0, &classIdPoint);
		const float max_class_score = static_cast<float>(maxVal);
		if (max_class_score < CONF_THRESHOLD)
			continue;

		const float x = (pdata[0] - padw) * ratio_w; // box center x in src pixels
		const float y = (pdata[1] - padh) * ratio_h; // box center y in src pixels
		const float w = pdata[2] * ratio_w;
		const float h = pdata[3] * ratio_h;
		const int left = static_cast<int>(std::max(x - 0.5f * w, 0.0f));
		const int top = static_cast<int>(std::max(y - 0.5f * h, 0.0f));
		// All four vectors are filled together; nms_outputs indexes them in parallel (the
		// classIds and picked_proposals pushes were missing).
		outputObject.classIds.push_back(classIdPoint.x);
		outputObject.confidences.push_back(max_class_score * box_score);
		outputObject.boxes.push_back(Rect(left, top, int(w), int(h)));
		outputObject.picked_proposals.emplace_back(pdata + 5 + CLASSES, pdata + net_width);
	}
}

//! Applies non-maximum suppression to the candidates in `outputObject`. cv::dnn::NMSBoxes gets no
//! class ids, so boxes of different classes can suppress each other. Each kept box is clipped to
//! `src` and appended to `output`, and its mask coefficients are appended to
//! `temp_mask_proposals` (same order, one entry per result).
inline void YOLO::nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output)
{
	std::vector<int> nms_result;
	cv::dnn::NMSBoxes(outputObject.boxes, outputObject.confidences, CONF_THRESHOLD, NMS_THRESHOLD, nms_result);
	const Rect holeImgRect(0, 0, src.cols, src.rows);
	for (const int idx : nms_result)
	{
		OutputSeg result;
		result.id = outputObject.classIds[idx];
		result.confidence = outputObject.confidences[idx];
		result.box = outputObject.boxes[idx] & holeImgRect;
		// A box entirely outside the image clips to zero area and has nothing to mask or draw.
		if (result.box.area() <= 0)
			continue;
		temp_mask_proposals.push_back(outputObject.picked_proposals[idx]);
		output.push_back(result);
	}
}


//! Builds each result's binary mask. Its coefficients are multiplied with the prototype masks in
//! output1 (`prob1`, viewed as _segChannels x (_segHeight * _segWidth)), passed through a sigmoid,
//! cropped to remove the letterbox padding, resized to `src`, and thresholded (MASK_THRESHOLD)
//! inside the result's box.
//! \param temp_mask_proposals One vector of _segChannels coefficients per entry of `output`, same order.
inline void YOLO::decode_mask(cv::Mat& src, float* prob1, std::vector<int> padsize, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output)
{
	if (output.empty() || temp_mask_proposals.size() != output.size())
		return;
	const int padh = padsize[2], padw = padsize[3];

	// n x _segChannels coefficient matrix (this fill step was missing).
	Mat maskProposals;
	for (const std::vector<float>& proposal : temp_mask_proposals)
		maskProposals.push_back(Mat(proposal).reshape(1, 1));

	// View over output1 without copying it (previously copied into a temporary vector).
	const Mat protos(_segChannels, _segHeight * _segWidth, CV_32F, prob1);
	// n x (_segHeight * _segWidth): row i holds result i's flattened mask logits.
	const Mat matmulRes = maskProposals * protos;

	// Letterbox padding converted to prototype-mask pixels. The width/height used to be
	// `_seg - pad / 2`, which is only right when INPUT / _seg == 4.
	const int roi_x = int((float)padw / INPUT_W * _segWidth);
	const int roi_y = int((float)padh / INPUT_H * _segHeight);
	const Rect roi(roi_x, roi_y,
		_segWidth - padw * 2 * _segWidth / INPUT_W,
		_segHeight - padh * 2 * _segHeight / INPUT_H);

	for (size_t i = 0; i < output.size(); ++i)
	{
		// Per-row reshape instead of a multi-channel reshape, which fails for more than 512 results.
		Mat dest;
		cv::exp(-matmulRes.row(static_cast<int>(i)).reshape(1, _segHeight), dest);
		dest = 1.0 / (1.0 + dest); // sigmoid, _segHeight x _segWidth

		Mat fullMask;
		// The old call passed INTER_NEAREST as the fx argument, so bilinear interpolation was in
		// effect; keep it, now stated explicitly.
		resize(dest(roi), fullMask, src.size(), 0, 0, INTER_LINEAR);
		output[i].boxMask = fullMask(output[i].box) > MASK_THRESHOLD;
	}
}

//! Fills each result's mask region with its class color (from color_list) and blends that layer
//! onto `img` at 50% opacity, in place.
inline void YOLO::drawMask(Mat& img, std::vector<OutputSeg> result)
{
	// color_list has a fixed number of entries; wrap the class id so models with more classes
	// than that do not read past the table (the old per-class loop up to CLASSES did).
	const int paletteSize = static_cast<int>(sizeof(color_list) / sizeof(color_list[0]));
	Mat mask = img.clone();
	for (const OutputSeg& det : result)
	{
		const float* c = color_list[det.id % paletteSize];
		const Scalar color(int(c[0] * 255), int(c[1] * 255), int(c[2] * 255));
		mask(det.box).setTo(color, det.boxMask);
	}
	addWeighted(img, 0.5, mask, 0.5, 0, img); // blend the mask layer onto the image
}


void YOLO::detect_img(std::string image_path)
	cv::Mat img = cv::imread(image_path);

	cv::Mat pr_img;
	std::vector<int> padsize;
	pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize);       // Resize

	float* blob = new float[3 * INPUT_H * INPUT_W];
    blobFromImage(pr_img, blob);

	auto start = std::chrono::system_clock::now();
	doInference(*context, blob, prob, prob1, batchSize);
	auto end = std::chrono::system_clock::now();
	std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;

	OutputObject outputObject;
	std::vector<OutputSeg> output;
	std::vector<std::vector<float>> temp_mask_proposals;
	decode_boxs(img, prob, outputObject, padsize);
	nms_outputs(img, outputObject, temp_mask_proposals, output);
	if (output.size() != 0)
		decode_mask(img, prob1, padsize, temp_mask_proposals, output);

	DrawPred(img, output);
	cv::imwrite("output.jpg", img);
	cv::imshow("output", img);

    delete[] blob;

/// <summary>
/// 读取图片进行推理,用数组将检测结果传出(包括label序号、置信度分数、矩形参数),用uchar数组将mask图片传出
/// </summary>
/// <param name="image_path"></param>
/// <param name="res_array"></param>
/// <param name="mask_array"></param>
void YOLO::detect_img(char* image_path, float(*res_array)[6], uchar(*mask_array))
	cv::Mat img = cv::imread(image_path);

	cv::Mat pr_img;
	std::vector<int> padsize;
	pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize);       // Resize

	float* blob = new float[3 * INPUT_H * INPUT_W];
    blobFromImage(pr_img, blob);

	auto start = std::chrono::system_clock::now();
	doInference(*context, blob, prob, prob1, batchSize);
	auto end = std::chrono::system_clock::now();
	std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;

	OutputObject outputObject;
	std::vector<OutputSeg> output;
	std::vector<std::vector<float>> temp_mask_proposals;
	decode_boxs(img, prob, outputObject, padsize);
	nms_outputs(img, outputObject, temp_mask_proposals, output);
	if (output.size() != 0)
		decode_mask(img, prob1, padsize, temp_mask_proposals, output);

	Mat mask = Mat(img.rows, img.cols, img.type(), Scalar(255, 255, 255));
	drawMask(mask, output);

	for (size_t j = 0; j < output.size(); j++)
		res_array[j][0] = output[j].box.x;
		res_array[j][1] = output[j].box.y;
		res_array[j][2] = output[j].box.width;
		res_array[j][3] = output[j].box.height;
		res_array[j][4] = output[j].id;
		res_array[j][5] = output[j].confidence;
		//mask_array = output[j].boxMask.data;

	for (int i = 0; i < mask.rows; i++) {
		for (int j = 0; j < mask.cols; j++) {
			for (int k = 0; k < 3; k++) {
				mask_array[i * mask.cols * 3 + j * 3 + k] = mask.at<cv::Vec3b>(i, j)[k];

	delete[] blob;



#include <iostream>
#include "yolo.hpp"

int main()
    YOLO yolo;   




#include <iostream>
#include "yolo.hpp"

//int main()
//    YOLO yolo;   
//    yolo.init("E:\\yolov5s-seg.engine");
//    yolo.detect_img("E:\\bus.jpg");
//    yolo.destroy();

// Single YOLO instance shared by the exported C functions below; calls are not synchronized,
// so callers must not use it from several threads at once.
YOLO yolo;

namespace
{
// Runs `f`, reporting any exception to std::cerr instead of letting it cross the extern "C"
// boundary into the caller (e.g. a .NET P/Invoke host), where it cannot be handled.
template <typename F>
void runGuarded(const char* where, F&& f)
{
    try
    {
        f();
    }
    catch (const std::exception& e)
    {
        std::cerr << where << " failed: " << e.what() << std::endl;
    }
    catch (...)
    {
        std::cerr << where << " failed: unknown exception" << std::endl;
    }
}
} // namespace

//! Loads the TensorRT engine at engine_path into the shared instance.
extern "C" __declspec(dllexport) void Init(char* engine_path)
{
    runGuarded("Init", [&] { yolo.init(engine_path); });
}

//! Runs detection on image_path; see YOLO::detect_img(char*, ...) for the array layouts.
extern "C" __declspec(dllexport) void Detect_Img(char* image_path, float(*res_array)[6], uchar(*mask_array))
{
    runGuarded("Detect_Img", [&] { yolo.detect_img(image_path, res_array, mask_array); });
}

//! Releases the TensorRT objects held by the shared instance.
extern "C" __declspec(dllexport) void Destroy()
{
    runGuarded("Destroy", [&] { yolo.destroy(); });
}


 /// <summary>Calls the native Init export: loads the TensorRT engine file at <paramref name="engine_path"/>.</summary>
 [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
 public static extern void Init(string engine_path);

 /// <summary>Calls the native Detect_Img export on the image at <paramref name="path"/>.</summary>
 /// <param name="resultArray">Receives one row of 6 floats per detection (x, y, width, height, class id, confidence); must be large enough for every detection.</param>
 /// <param name="classArray">First element of a byte buffer of rows * cols * 3 bytes that receives the mask image (pass <c>ref buffer[0]</c>).</param>
 [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
 public static extern void Detect_Img(string path, float[,] resultArray, ref byte classArray);

 /// <summary>Calls the native Destroy export, which releases the loaded TensorRT engine.</summary>
 [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
 public static extern void Destroy();



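         我们使用 Netron 打开对应的 ONNX 模型，可以看到输入 images 以及输出 output0、output1 与代码中的绑定名称一一对应。训练自己的模型时通常需要修改 CLASSES，请根据自己模型的参数相应修改；如果输入尺寸或掩码原型的形状不同，INPUT_H/INPUT_W、Num_box 以及 _segWidth/_segHeight/_segChannels 也需要同步修改。置信度阈值（CONF_THRESHOLD）也可根据需要自行调整。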





