paddleocr tensorrt推理


cmake_minimum_required(VERSION 3.24)
project(ocr LANGUAGES CXX)

# Build settings shared by all sources.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTORCC ON)
set(CMAKE_AUTOUIC ON)

# Machine-specific install locations. They are cache entries so another machine can
# override them with -D<var>=<path> instead of editing this file.
set(CMAKE_PREFIX_PATH "D:/Qt/5.15.2/msvc2019_64" CACHE PATH "Qt 5 installation prefix")
set(OpenCV_DIR "D:/opencv/build/x64/vc16/lib" CACHE PATH "Directory containing OpenCVConfig.cmake")
set(TENSORRT_LIB_DIR "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7/lib/x64"
        CACHE PATH "Directory containing the TensorRT import libraries")

find_package(Qt5 COMPONENTS Core Gui Widgets REQUIRED)
find_package(OpenCV REQUIRED)
# FindCUDA (find_package(CUDA)) is deprecated; CUDAToolkit provides imported targets such as CUDA::cudart.
find_package(CUDAToolkit REQUIRED)
message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")

# TensorRT ships no CMake package, so locate its import libraries explicitly.
set(TENSORRT_LIBRARIES)
foreach (TRT_LIB nvinfer nvonnxparser nvinfer_plugin)
    find_library(TENSORRT_${TRT_LIB}_LIBRARY ${TRT_LIB}
            HINTS "${TENSORRT_LIB_DIR}" "${CUDAToolkit_LIBRARY_DIR}"
            REQUIRED)
    list(APPEND TENSORRT_LIBRARIES "${TENSORRT_${TRT_LIB}_LIBRARY}")
endforeach ()

add_executable(ocr main.cpp logger.cpp engine.cpp)
target_include_directories(ocr PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(ocr PRIVATE
        ${OpenCV_LIBS}
        ${TENSORRT_LIBRARIES}
        CUDA::cudart
        Qt5::Core
        Qt5::Gui
        Qt5::Widgets)

if (WIN32)
    # Copy the Qt runtime DLLs next to the executable so it can be started from the build tree.
    # The "d" suffix is chosen per configuration with a generator expression, which (unlike
    # CMAKE_BUILD_TYPE) also works with multi-config generators such as Visual Studio.
    set(DEBUG_SUFFIX "$<$<CONFIG:Debug>:d>")
    set(QT_INSTALL_PATH "${CMAKE_PREFIX_PATH}")
    if (NOT EXISTS "${QT_INSTALL_PATH}/bin")
        set(QT_INSTALL_PATH "${QT_INSTALL_PATH}/..")
        if (NOT EXISTS "${QT_INSTALL_PATH}/bin")
            set(QT_INSTALL_PATH "${QT_INSTALL_PATH}/..")
        endif ()
    endif ()
    if (EXISTS "${QT_INSTALL_PATH}/plugins/platforms/qwindows.dll")
        add_custom_command(TARGET ocr POST_BUILD
                COMMAND ${CMAKE_COMMAND} -E make_directory
                "$<TARGET_FILE_DIR:ocr>/plugins/platforms/")
        add_custom_command(TARGET ocr POST_BUILD
                COMMAND ${CMAKE_COMMAND} -E copy
                "${QT_INSTALL_PATH}/plugins/platforms/qwindows${DEBUG_SUFFIX}.dll"
                "$<TARGET_FILE_DIR:ocr>/plugins/platforms/")
    endif ()
    foreach (QT_LIB Core Gui Widgets)
        add_custom_command(TARGET ocr POST_BUILD
                COMMAND ${CMAKE_COMMAND} -E copy
                "${QT_INSTALL_PATH}/bin/Qt5${QT_LIB}${DEBUG_SUFFIX}.dll"
                "$<TARGET_FILE_DIR:ocr>")
    endforeach (QT_LIB)
endif ()


#include "NvInfer.h"
#include "common.h"
#include "half.h"
#include <cassert>
#include <cuda_runtime_api.h>
#include <iostream>
#include <iterator>
#include <memory>
#include <new>
#include <numeric>
#include <string>
#include <vector>

namespace samplesCommon

//! \brief  The GenericBuffer class is a templated class for buffers.
//! \details This templated RAII (Resource Acquisition Is Initialization) class handles the allocation,
//!          deallocation, querying of buffers on both the device and the host.
//!          It can handle data of arbitrary types because it stores byte buffers.
//!          The template parameters AllocFunc and FreeFunc are used for the
//!          allocation and deallocation of the buffer.
//!          AllocFunc must be a functor that takes in (void** ptr, size_t size)
//!          and returns bool. ptr is a pointer to where the allocated buffer address should be stored.
//!          size is the amount of memory in bytes to allocate.
//!          The boolean indicates whether or not the memory allocation was successful.
//!          FreeFunc must be a functor that takes in (void* ptr) and returns void.
//!          ptr is the allocated buffer address. It must work with nullptr input.
template <typename AllocFunc, typename FreeFunc>
class GenericBuffer
    //! \brief Construct an empty buffer.
    GenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
        : mSize(0)
        , mCapacity(0)
        , mType(type)
        , mBuffer(nullptr)

    //! \brief Construct a buffer with the specified allocation size in bytes.
    GenericBuffer(size_t size, nvinfer1::DataType type)
        : mSize(size)
        , mCapacity(size)
        , mType(type)
        if (!allocFn(&mBuffer, this->nbBytes()))
            throw std::bad_alloc();

    GenericBuffer(GenericBuffer&& buf)
        : mSize(buf.mSize)
        , mCapacity(buf.mCapacity)
        , mType(buf.mType)
        , mBuffer(buf.mBuffer)
        buf.mSize = 0;
        buf.mCapacity = 0;
        buf.mType = nvinfer1::DataType::kFLOAT;
        buf.mBuffer = nullptr;

    GenericBuffer& operator=(GenericBuffer&& buf)
        if (this != &buf)
            mSize = buf.mSize;
            mCapacity = buf.mCapacity;
            mType = buf.mType;
            mBuffer = buf.mBuffer;
            // Reset buf.
            buf.mSize = 0;
            buf.mCapacity = 0;
            buf.mBuffer = nullptr;
        return *this;

    //! \brief Returns pointer to underlying array.
    void* data()
        return mBuffer;

    //! \brief Returns pointer to underlying array.
    const void* data() const
        return mBuffer;

    //! \brief Returns the size (in number of elements) of the buffer.
    size_t size() const
        return mSize;

    //! \brief Returns the size (in bytes) of the buffer.
    size_t nbBytes() const
        return this->size() * samplesCommon::getElementSize(mType);

    //! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity.
    void resize(size_t newSize)
        mSize = newSize;
        if (mCapacity < newSize)
            if (!allocFn(&mBuffer, this->nbBytes()))
                throw std::bad_alloc{};
            mCapacity = newSize;

    //! \brief Overload of resize that accepts Dims
    void resize(const nvinfer1::Dims& dims)
        return this->resize(samplesCommon::volume(dims));


    size_t mSize{0}, mCapacity{0};
    nvinfer1::DataType mType;
    void* mBuffer;
    AllocFunc allocFn;
    FreeFunc freeFn;

class DeviceAllocator
    bool operator()(void** ptr, size_t size) const
        return cudaMalloc(ptr, size) == cudaSuccess;

class DeviceFree
    void operator()(void* ptr) const

class HostAllocator
    bool operator()(void** ptr, size_t size) const
        *ptr = malloc(size);
        return *ptr != nullptr;

class HostFree
    void operator()(void* ptr) const

using DeviceBuffer = GenericBuffer<DeviceAllocator, DeviceFree>;
using HostBuffer = GenericBuffer<HostAllocator, HostFree>;

//! \brief  The ManagedBuffer class groups together a pair of corresponding device and host buffers.
class ManagedBuffer
    DeviceBuffer deviceBuffer;
    HostBuffer hostBuffer;

//! \brief  The BufferManager class handles host and device buffer allocation and deallocation.
//! \details This RAII class handles host and device buffer allocation and deallocation,
//!          memcpy between host and device buffers to aid with inference,
//!          and debugging dumps to validate inference. The BufferManager class is meant to be
//!          used to simplify buffer management and any interactions between buffers and the engine.
class BufferManager
    static const size_t kINVALID_SIZE_VALUE = ~size_t(0);

    //! \brief Create a BufferManager for handling buffer interactions with engine.
    BufferManager(std::shared_ptr<nvinfer1::ICudaEngine> engine, const int batchSize = 0,
        const nvinfer1::IExecutionContext* context = nullptr)
        : mEngine(engine)
        , mBatchSize(batchSize)
        // Full Dims implies no batch size.
        assert(engine->hasImplicitBatchDimension() || mBatchSize == 0);
        // Create host and device buffers
        for (int i = 0; i < mEngine->getNbBindings(); i++)
            auto dims = context ? context->getBindingDimensions(i) : mEngine->getBindingDimensions(i);
            size_t vol = context || !mBatchSize ? 1 : static_cast<size_t>(mBatchSize);
            nvinfer1::DataType type = mEngine->getBindingDataType(i);
            int vecDim = mEngine->getBindingVectorizedDim(i);
            if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector
                int scalarsPerVec = mEngine->getBindingComponentsPerElement(i);
                dims.d[vecDim] = divUp(dims.d[vecDim], scalarsPerVec);
                vol *= scalarsPerVec;
            vol *= samplesCommon::volume(dims);
            std::unique_ptr<ManagedBuffer> manBuf{new ManagedBuffer()};
            manBuf->deviceBuffer = DeviceBuffer(vol, type);
            manBuf->hostBuffer = HostBuffer(vol, type);

    //! \brief Returns a vector of device buffers that you can use directly as
    //!        bindings for the execute and enqueue methods of IExecutionContext.
    std::vector<void*>& getDeviceBindings()
        return mDeviceBindings;

    //! \brief Returns a vector of device buffers.
    const std::vector<void*>& getDeviceBindings() const
        return mDeviceBindings;

    //! \brief Returns the device buffer corresponding to tensorName.
    //!        Returns nullptr if no such tensor can be found.
    void* getDeviceBuffer(const std::string& tensorName) const
        return getBuffer(false, tensorName);

    //! \brief Returns the host buffer corresponding to tensorName.
    //!        Returns nullptr if no such tensor can be found.
    void* getHostBuffer(const std::string& tensorName) const
        return getBuffer(true, tensorName);

    //! \brief Returns the size of the host and device buffers that correspond to tensorName.
    //!        Returns kINVALID_SIZE_VALUE if no such tensor can be found.
    size_t size(const std::string& tensorName) const
        int index = mEngine->getBindingIndex(tensorName.c_str());
        if (index == -1)
            return kINVALID_SIZE_VALUE;
        return mManagedBuffers[index]->hostBuffer.nbBytes();

    //! \brief Templated print function that dumps buffers of arbitrary type to std::ostream.
    //!        rowCount parameter controls how many elements are on each line.
    //!        A rowCount of 1 means that there is only 1 element on each line.
    template <typename T>
    void print(std::ostream& os, void* buf, size_t bufSize, size_t rowCount)
        assert(rowCount != 0);
        assert(bufSize % sizeof(T) == 0);
        T* typedBuf = static_cast<T*>(buf);
        size_t numItems = bufSize / sizeof(T);
        for (int i = 0; i < static_cast<int>(numItems); i++)
            // Handle rowCount == 1 case
            if (rowCount == 1 && i != static_cast<int>(numItems) - 1)
                os << typedBuf[i] << std::endl;
            else if (rowCount == 1)
                os << typedBuf[i];
            // Handle rowCount > 1 case
            else if (i % rowCount == 0)
                os << typedBuf[i];
            else if (i % rowCount == rowCount - 1)
                os << " " << typedBuf[i] << std::endl;
                os << " " << typedBuf[i];

    //! \brief Copy the contents of input host buffers to input device buffers synchronously.
    void copyInputToDevice()
        memcpyBuffers(true, false, false);

    //! \brief Copy the contents of output device buffers to output host buffers synchronously.
    void copyOutputToHost()
        memcpyBuffers(false, true, false);

    //! \brief Copy the contents of input host buffers to input device buffers asynchronously.
    void copyInputToDeviceAsync(const cudaStream_t& stream = 0)
        memcpyBuffers(true, false, true, stream);

    //! \brief Copy the contents of output device buffers to output host buffers asynchronously.
    void copyOutputToHostAsync(const cudaStream_t& stream = 0)
        memcpyBuffers(false, true, true, stream);

    ~BufferManager() = default;

    void* getBuffer(const bool isHost, const std::string& tensorName) const
        int index = mEngine->getBindingIndex(tensorName.c_str());
        if (index == -1)
            return nullptr;
        return (isHost ? mManagedBuffers[index]-> : mManagedBuffers[index]->;

    void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0)
        for (int i = 0; i < mEngine->getNbBindings(); i++)
            void* dstPtr
                = deviceToHost ? mManagedBuffers[i]-> : mManagedBuffers[i]->;
            const void* srcPtr
                = deviceToHost ? mManagedBuffers[i]-> : mManagedBuffers[i]->;
            const size_t byteSize = mManagedBuffers[i]->hostBuffer.nbBytes();
            const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
            if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i)))
                if (async)
                    CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream));
                    CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType));

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine;              //!< The pointer to the engine
    int mBatchSize;                                              //!< The batch size for legacy networks, 0 otherwise.
    std::vector<std::unique_ptr<ManagedBuffer>> mManagedBuffers; //!< The vector of pointers to managed buffers
    std::vector<void*> mDeviceBindings;                          //!< The vector of device buffers needed for engine execution

} // namespace samplesCommon



// For loadLibrary
#ifdef _MSC_VER
// Needed so that the max/min definitions in windows.h do not conflict with std::max/min.
#define NOMINMAX
#include <windows.h>
#undef NOMINMAX
#else
#include <dlfcn.h>
#endif

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "logger.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <cuda_runtime_api.h>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ratio>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#ifndef _MSC_VER
#include <stdio.h>  // fileno
#include <unistd.h> // lockf
#endif

#include "safeCommon.h"

// Name of the enclosing function for log messages: MSVC spells it __FUNCTION__, standard C++ uses __func__.
#ifdef _MSC_VER
#define FN_NAME __FUNCTION__
#else
#define FN_NAME __func__
#endif

// DLA-related API is only enabled on aarch64 and QNX targets.
#if defined(__aarch64__) || defined(__QNX__)
#define ENABLE_DLA_API 1
#endif

//! When status is false: log errMsg plus the source file, enclosing function and line to
//! sample::gLogError, then return val from the enclosing function.
//! The do/while(0) wrapper makes the expansion behave as a single statement; `break` leaves it early.
#define CHECK_RETURN_W_MSG(status, val, errMsg)                                                                        \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((status))                                                                                                  \
        {                                                                                                              \
            break;                                                                                                     \
        }                                                                                                              \
        sample::gLogError << errMsg << " Error in " << __FILE__ << ", function " << FN_NAME << "(), line " << __LINE__ \
                          << std::endl;                                                                                \
        return val;                                                                                                    \
    } while (0)

#undef ASSERT
//! Release-mode assertion: always evaluates condition; if it is false, logs the stringified
//! condition to sample::gLogError and calls abort().
#define ASSERT(condition)                                                          \
    do                                                                             \
    {                                                                              \
        if ((condition))                                                           \
        {                                                                          \
            break;                                                                 \
        }                                                                          \
        sample::gLogError << "Assertion failure: " << #condition << std::endl;     \
        abort();                                                                   \
    } while (0)

//! CHECK_RETURN_W_MSG without an extra message.
#define CHECK_RETURN(status, val) CHECK_RETURN_W_MSG((status), (val), "")

//! unique_ptr to A whose deleter is a plain function pointer taking A*.
#define OBJ_GUARD(A) std::unique_ptr<A, void (*)(A*)>

//! Wrap t in a unique_ptr (the same type as OBJ_GUARD(T)) that calls t->destroy() instead of delete.
//! T_ must be T or derive from T; this is now checked at compile time.
template <typename T, typename T_>
std::unique_ptr<T, void (*)(T*)> makeObjGuard(T_* t)
{
    static_assert(std::is_base_of<T, T_>::value || std::is_same<T, T_>::value,
        "makeObjGuard: T_ must be T or derive from T");
    void (*deleter)(T*) = [](T* obj) { obj->destroy(); };
    return std::unique_ptr<T, void (*)(T*)>{static_cast<T*>(t), deleter};
}

//! Size-unit literals: the value multiplied by 2^30, 2^20 or 2^10 (a byte count), as long double.
constexpr long double operator"" _GiB(long double val)
{
    return val * (1 << 30);
}
constexpr long double operator"" _MiB(long double val)
{
    return val * (1 << 20);
}
constexpr long double operator"" _KiB(long double val)
{
    return val * (1 << 10);
}

// Integer-literal overloads so that e.g. 4_GiB works without having to write 4.0_GiB.
constexpr long double operator"" _GiB(unsigned long long val)
{
    return static_cast<long double>(val) * (1 << 30);
}
constexpr long double operator"" _MiB(unsigned long long val)
{
    return static_cast<long double>(val) * (1 << 20);
}
constexpr long double operator"" _KiB(unsigned long long val)
{
    return static_cast<long double>(val) * (1 << 10);
}

struct SimpleProfiler : public nvinfer1::IProfiler
    struct Record
        float time{0};
        int count{0};

    virtual void reportLayerTime(const char* layerName, float ms) noexcept
        mProfile[layerName].time += ms;
        if (std::find(mLayerNames.begin(), mLayerNames.end(), layerName) == mLayerNames.end())

    SimpleProfiler(const char* name, const std::vector<SimpleProfiler>& srcProfilers = std::vector<SimpleProfiler>())
        : mName(name)
        for (const auto& srcProfiler : srcProfilers)
            for (const auto& rec : srcProfiler.mProfile)
                auto it = mProfile.find(rec.first);
                if (it == mProfile.end())
                    it->second.time += rec.second.time;
                    it->second.count += rec.second.count;

    friend std::ostream& operator<<(std::ostream& out, const SimpleProfiler& value)
        out << "========== " << value.mName << " profile ==========" << std::endl;
        float totalTime = 0;
        std::string layerNameStr = "TensorRT layer name";
        int maxLayerNameLength = std::max(static_cast<int>(layerNameStr.size()), 70);
        for (const auto& elem : value.mProfile)
            totalTime += elem.second.time;
            maxLayerNameLength = std::max(maxLayerNameLength, static_cast<int>(elem.first.size()));

        auto old_settings = out.flags();
        auto old_precision = out.precision();
        // Output header
            out << std::setfill(' ') << std::setw(maxLayerNameLength) << layerNameStr << " ";
            out << std::setw(12) << "Runtime, "
                << "%"
                << " ";
            out << std::setw(12) << "Invocations"
                << " ";
            out << std::setw(12) << "Runtime, ms" << std::endl;
        for (size_t i = 0; i < value.mLayerNames.size(); i++)
            const std::string layerName = value.mLayerNames[i];
            auto elem =;
            out << std::setw(maxLayerNameLength) << layerName << " ";
            out << std::setw(12) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%"
                << " ";
            out << std::setw(12) << elem.count << " ";
            out << std::setw(12) << std::fixed << std::setprecision(2) << elem.time << std::endl;
        out << "========== " << value.mName << " total runtime = " << totalTime << " ms ==========" << std::endl;

        return out;

    std::string mName;
    std::vector<std::string> mLayerNames;
    std::map<std::string, Record> mProfile;

//! Locate path to file, given its filename or filepath suffix and possible dirs it might lie in.
//! Function will also walk back MAX_DEPTH dirs from CWD to check for such a file path.
//! \param filepathSuffix File name (or trailing part of a path) to look for.
//! \param directories Candidate directories, tried in order; may be empty.
//! \param reportError If true and nothing is found, print "&&&& FAILED" and exit(EXIT_FAILURE).
//! \return The first path that could be opened, or an empty string if none was found.
inline std::string locateFile(
    const std::string& filepathSuffix, const std::vector<std::string>& directories, bool reportError = true)
{
    const int MAX_DEPTH{10};
    bool found{false};
    std::string filepath;

    for (const auto& dir : directories)
    {
#ifdef _MSC_VER
        const bool endsWithSeparator = !dir.empty() && (dir.back() == '/' || dir.back() == '\\');
#else
        const bool endsWithSeparator = !dir.empty() && dir.back() == '/';
#endif
        if (!dir.empty() && !endsWithSeparator)
        {
#ifdef _MSC_VER
            filepath = dir + "\\" + filepathSuffix;
#else
            filepath = dir + "/" + filepathSuffix;
#endif
        }
        else
        {
            filepath = dir + filepathSuffix;
        }

        for (int i = 0; i < MAX_DEPTH && !found; i++)
        {
            const std::ifstream checkFile(filepath);
            found = checkFile.is_open();
            if (found)
            {
                break;
            }

            filepath = "../" + filepath; // Try again in parent dir
        }

        if (found)
        {
            break;
        }

        // Nothing under this directory; do not let its "../../..." path leak into the result.
        filepath.clear();
    }

    // Could not find the file
    if (filepath.empty())
    {
        // Built without directories.front() so an empty directory list is not undefined behaviour.
        std::string dirList;
        for (const auto& dir : directories)
        {
            dirList += "\n\t" + dir;
        }
        std::cout << "Could not find " << filepathSuffix << " in data directories:" << dirList << std::endl;

        if (reportError)
        {
            std::cout << "&&&& FAILED" << std::endl;
            exit(EXIT_FAILURE);
        }
    }

    return filepath;
}

//! Read the pixel bytes of a binary PGM file into buffer.
//! \param fileName Path of the PGM file.
//! \param buffer Destination; must hold at least inH * inW bytes.
//! \param inH Number of rows to read.
//! \param inW Number of columns to read.
//! The header values are parsed only to skip past them; they are not checked against inH/inW.
inline void readPGMFile(const std::string& fileName, uint8_t* buffer, int inH, int inW)
{
    std::ifstream infile(fileName, std::ifstream::binary);
    assert(infile.is_open() && "Attempting to read from a file that is not open.");
    // PGM header order is: magic, width, height, max gray value.
    std::string magic, width, height, max;
    infile >> magic >> width >> height >> max;
    // Skip the single whitespace byte separating the header from the pixel data.
    infile.seekg(1, infile.cur);
    infile.read(reinterpret_cast<char*>(buffer), inH * inW);
}

namespace samplesCommon

// Swaps endianness of an integral type.
// Bytes are moved with memcpy: reinterpreting a uint8_t array as T violates strict aliasing and
// may be misaligned.
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T swapEndianness(const T& value)
{
    uint8_t bytes[sizeof(T)];
    std::memcpy(bytes, &value, sizeof(T));
    std::reverse(bytes, bytes + sizeof(T));
    T result;
    std::memcpy(&result, bytes, sizeof(T));
    return result;
}

class HostMemory
    HostMemory() = delete;
    virtual void* data() const noexcept
        return mData;
    virtual std::size_t size() const noexcept
        return mSize;
    virtual nvinfer1::DataType type() const noexcept
        return mType;
    virtual ~HostMemory() {}

    HostMemory(std::size_t size, nvinfer1::DataType type)
        : mData{nullptr}
        , mSize(size)
        , mType(type)
    void* mData;
    std::size_t mSize;
    nvinfer1::DataType mType;

template <typename ElemType, nvinfer1::DataType dataType>
class TypedHostMemory : public HostMemory
    explicit TypedHostMemory(std::size_t size)
        : HostMemory(size, dataType)
        mData = new ElemType[size];
    ~TypedHostMemory() noexcept
        delete[](ElemType*) mData;
    ElemType* raw() noexcept
        return static_cast<ElemType*>(data());

using FloatMemory = TypedHostMemory<float, nvinfer1::DataType::kFLOAT>;
using HalfMemory = TypedHostMemory<uint16_t, nvinfer1::DataType::kHALF>;
using ByteMemory = TypedHostMemory<uint8_t, nvinfer1::DataType::kINT8>;

inline void* safeCudaMalloc(size_t memSize)
    void* deviceMem;
    CHECK(cudaMalloc(&deviceMem, memSize));
    if (deviceMem == nullptr)
        std::cerr << "Out of memory" << std::endl;
    return deviceMem;

//! \return true when the TENSORRT_DEBUG environment variable is set (to any value).
inline bool isDebug()
{
    return std::getenv("TENSORRT_DEBUG") != nullptr;
}

//! Deleter that releases the object with plain delete; used by SampleUniquePtr.
struct InferDeleter
{
    template <typename T>
    void operator()(T* obj) const
    {
        delete obj;
    }
};

template <typename T>
using SampleUniquePtr = std::unique_ptr<T, InferDeleter>;

static auto StreamDeleter = [](cudaStream_t* pStream)
        if (pStream)
            delete pStream;

inline std::unique_ptr<cudaStream_t, decltype(StreamDeleter)> makeCudaStream()
    std::unique_ptr<cudaStream_t, decltype(StreamDeleter)> pStream(new cudaStream_t, StreamDeleter);
    if (cudaStreamCreateWithFlags(pStream.get(), cudaStreamNonBlocking) != cudaSuccess)

    return pStream;

//! Return vector of indices that puts magnitudes of sequence in descending order.
//! Indices of equal-magnitude elements keep their original relative order.
template <class Iter>
std::vector<size_t> argMagnitudeSort(Iter begin, Iter end)
{
    std::vector<size_t> indices(end - begin);
    std::iota(indices.begin(), indices.end(), size_t{0});
    // stable_sort makes the ordering of ties deterministic across standard-library implementations.
    std::stable_sort(indices.begin(), indices.end(),
        [&begin](size_t i, size_t j) { return std::abs(begin[j]) < std::abs(begin[i]); });
    return indices;
}

//! Append every non-empty line of fileName to refVector.
//! \return false (after printing an error) if the file cannot be opened, true otherwise.
inline bool readReferenceFile(const std::string& fileName, std::vector<std::string>& refVector)
{
    std::ifstream infile(fileName);
    if (!infile.is_open())
    {
        std::cout << "ERROR: readReferenceFile: Attempting to read from a file that is not open." << std::endl;
        return false;
    }
    std::string line;
    while (std::getline(infile, line))
    {
        // Tolerate CRLF files on platforms whose text mode keeps the '\r'.
        if (!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        if (line.empty())
        {
            continue;
        }
        refVector.push_back(line);
    }
    return true;
}

template <typename T>
std::vector<std::string> classify(
    const std::vector<std::string>& refVector, const std::vector<T>& output, const size_t topK)
    const auto inds = samplesCommon::argMagnitudeSort(output.cbegin(), output.cend());
    std::vector<std::string> result;
    for (size_t k = 0; k < topK; ++k)
    return result;

// Returns indices of highest K magnitudes in v.
template <typename T>
std::vector<size_t> topKMagnitudes(const std::vector<T>& v, const size_t k)
    std::vector<size_t> indices = samplesCommon::argMagnitudeSort(v.cbegin(), v.cend());
    return indices;

//! Replace out with the whitespace-separated values of type T parsed from fileName.
//! \param size Expected number of values; used only as a capacity hint.
//! \return false (after printing an error) if the file cannot be opened, true otherwise.
template <typename T>
bool readASCIIFile(const std::string& fileName, const size_t size, std::vector<T>& out)
{
    std::ifstream infile(fileName);
    if (!infile.is_open())
    {
        std::cout << "ERROR readASCIIFile: Attempting to read from a file that is not open." << std::endl;
        return false;
    }
    out.clear();
    out.reserve(size);
    out.assign(std::istream_iterator<T>(infile), std::istream_iterator<T>());
    return true;
}

//! Write each element of in on its own line of fileName.
//! \return false if the file cannot be opened (an error is printed) or if writing/closing fails.
template <typename T>
bool writeASCIIFile(const std::string& fileName, const std::vector<T>& in)
{
    std::ofstream outfile(fileName);
    if (!outfile.is_open())
    {
        std::cout << "ERROR: writeASCIIFile: Attempting to write to a file that is not open." << std::endl;
        return false;
    }
    for (const auto& fn : in)
    {
        outfile << fn << "\n";
    }
    // close() flushes; a failed write or flush sets failbit, which is reported to the caller.
    outfile.close();
    return !outfile.fail();
}

inline void print_version()
    std::cout << "  TensorRT version: " << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "." << NV_TENSORRT_PATCH
              << "." << NV_TENSORRT_BUILD << std::endl;

//! Extension of the last path component (text after its last '.'), or "" if it has none.
//! Example: "models/det.onnx" -> "onnx", "dir.v1/model" -> "".
inline std::string getFileType(const std::string& filepath)
{
    const size_t dot = filepath.find_last_of('.');
    const size_t sep = filepath.find_last_of("/\\");
    // No dot at all, or the last dot belongs to a directory name: previously npos + 1 == 0
    // returned the whole path, and a dotted directory produced a bogus "extension".
    if (dot == std::string::npos || (sep != std::string::npos && dot < sep))
    {
        return std::string();
    }
    return filepath.substr(dot + 1);
}

//! Copy of inp with each byte lower-cased via std::tolower in the current C locale.
inline std::string toLower(const std::string& inp)
{
    std::string out = inp;
    // std::tolower requires a value representable as unsigned char; passing a negative char
    // (e.g. a UTF-8 byte) is undefined behaviour, so convert first.
    std::transform(out.begin(), out.end(), out.begin(),
        [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return out;
}

//! Largest value among buffer[0 .. size); buffer must be non-null and size positive (asserted).
inline float getMaxValue(const float* buffer, int64_t size)
{
    assert(buffer != nullptr);
    assert(size > 0);
    return *std::max_element(buffer, buffer + size);
}

inline int32_t calculateSoftmax(float* const prob, int32_t const numDigits)
    ASSERT(prob != nullptr);
    ASSERT(numDigits == 10);
    float sum{0.0F};
    std::transform(prob, prob + numDigits, prob, [&sum](float v) -> float {
        sum += exp(v);
        return exp(v);

    ASSERT(sum != 0.0F);
    std::transform(prob, prob + numDigits, prob, [sum](float v) -> float { return v / sum; });
    int32_t idx = std::max_element(prob, prob + numDigits) - prob;
    return idx;

// Ensures that every tensor used by a network has a dynamic range set.
// All tensors in a network must have a dynamic range specified if a calibrator is not used.
// This function is just a utility to globally fill in missing scales and zero-points for the entire network.
// If a tensor does not have a dynamic range set, it is assigned inRange or outRange as follows:
// * If the tensor is the input to a layer or output of a pooling node, its dynamic range is derived from inRange.
// * Otherwise its dynamic range is derived from outRange.
// The default parameter values are intended to demonstrate, for final layers in the network,
// cases where dynamic ranges are asymmetric.
// The default parameter values are chosen arbitrarily. Range values should be chosen such that
// we avoid underflow or overflow. Also range value should be non zero to avoid uniform zero scale tensor.
inline void setAllDynamicRanges(nvinfer1::INetworkDefinition* network, float inRange = 2.0f, float outRange = 4.0f)
{
    // Ensure that all layer inputs have a scale.
    for (int i = 0; i < network->getNbLayers(); i++)
    {
        auto layer = network->getLayer(i);
        for (int j = 0; j < layer->getNbInputs(); j++)
        {
            nvinfer1::ITensor* input{layer->getInput(j)};
            // Optional inputs are nullptr here and are from RNN layers.
            if (input != nullptr && !input->dynamicRangeIsSet())
            {
                ASSERT(input->setDynamicRange(-inRange, inRange));
            }
        }
    }

    // Ensure that all layer outputs have a scale.
    // Tensors that are also inputs to layers are ignored here
    // since the previous loop nest assigned scales to them.
    for (int i = 0; i < network->getNbLayers(); i++)
    {
        auto layer = network->getLayer(i);
        for (int j = 0; j < layer->getNbOutputs(); j++)
        {
            nvinfer1::ITensor* output{layer->getOutput(j)};
            // Optional outputs are nullptr here and are from RNN layers.
            if (output != nullptr && !output->dynamicRangeIsSet())
            {
                // Pooling must have the same input and output scales.
                if (layer->getType() == nvinfer1::LayerType::kPOOLING)
                {
                    ASSERT(output->setDynamicRange(-inRange, inRange));
                }
                else
                {
                    ASSERT(output->setDynamicRange(-outRange, outRange));
                }
            }
        }
    }
}

//! When config c requests INT8, warn and give every tensor of network n a dummy [-127, 127] dynamic range.
inline void setDummyInt8DynamicRanges(const nvinfer1::IBuilderConfig* c, nvinfer1::INetworkDefinition* n)
{
    // Set dummy per-tensor dynamic range if Int8 mode is requested.
    if (c->getFlag(nvinfer1::BuilderFlag::kINT8))
    {
        sample::gLogWarning << "Int8 calibrator not provided. Generating dummy per-tensor dynamic range. Int8 accuracy "
                               "is not guaranteed."
                            << std::endl;
        setAllDynamicRanges(n, 127.0F, 127.0F);
    }
}

//! Configure config to run on DLA core useDLACore (no-op when useDLACore < 0).
//! \param allowGPUFallback Let layers unsupported on DLA fall back to the GPU.
//! FP16 is enabled unless INT8 was already requested.
inline void enableDLA(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, int useDLACore, bool allowGPUFallback = true)
{
    if (useDLACore < 0)
    {
        return;
    }
    if (builder->getNbDLACores() == 0)
    {
        std::cerr << "Trying to use DLA core " << useDLACore << " on a platform that doesn't have any DLA cores"
                  << std::endl;
        assert("Error: use DLA core on a platfrom that doesn't have any DLA cores" && false);
        // assert() vanishes in release builds; do not go on to configure a DLA that does not exist.
        return;
    }
    if (allowGPUFallback)
    {
        config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
    }
    if (!config->getFlag(nvinfer1::BuilderFlag::kINT8))
    {
        // User has not requested INT8 Mode.
        // By default run in FP16 mode. FP32 mode is not permitted.
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
    config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
    config->setDLACore(useDLACore);
}

//! Value N of the first "--useDLACore=N" argument after argv[0], or -1 if absent.
//! std::stoi throws std::invalid_argument / std::out_of_range for a malformed N.
inline int32_t parseDLA(int32_t argc, char** argv)
{
    static const char kPrefix[] = "--useDLACore=";
    const size_t prefixLen = sizeof(kPrefix) - 1;
    for (int32_t i = 1; i < argc; i++)
    {
        if (std::strncmp(argv[i], kPrefix, prefixLen) == 0)
        {
            return std::stoi(argv[i] + prefixLen);
        }
    }
    return -1;
}

//! Size in bytes of one element of TensorRT data type t; 0 for any type not listed.
inline uint32_t getElementSize(nvinfer1::DataType t) noexcept
{
    switch (t)
    {
    case nvinfer1::DataType::kINT32: return 4;
    case nvinfer1::DataType::kFLOAT: return 4;
    case nvinfer1::DataType::kHALF: return 2;
    case nvinfer1::DataType::kBOOL:
    case nvinfer1::DataType::kUINT8:
    case nvinfer1::DataType::kINT8: return 1;
    }
    return 0;
}

//! Product of dims.d[start, stop); aborts via ASSERT on an invalid range or a negative extent.
inline int64_t volume(nvinfer1::Dims const& dims, int32_t start, int32_t stop)
{
    ASSERT(start >= 0);
    ASSERT(start <= stop);
    ASSERT(stop <= dims.nbDims);
    ASSERT(std::all_of(dims.d + start, dims.d + stop, [](int32_t x) { return x >= 0; }));
    return std::accumulate(dims.d + start, dims.d + stop, int64_t{1}, std::multiplies<int64_t>{});
}

//! PPM image with a compile-time sized pixel buffer of C * H * W bytes.
template <int C, int H, int W>
struct PPM
{
    std::string magic, fileName;
    int h{0}, w{0}, max{0};
    uint8_t buffer[C * H * W];
};

// New vPPM(variable sized PPM) class with variable dimensions.
struct vPPM
{
    std::string magic, fileName;
    int h{0}, w{0}, max{0};
    std::vector<uint8_t> buffer;
};

//! Axis-aligned box given by two corner coordinates.
struct BBox
{
    float x1{0}, y1{0}, x2{0}, y2{0};
};

template <int C, int H, int W>
void readPPMFile(const std::string& filename, samplesCommon::PPM<C, H, W>& ppm)
    ppm.fileName = filename;
    std::ifstream infile(filename, std::ifstream::binary);
    assert(infile.is_open() && "Attempting to read from a file that is not open.");
    infile >> ppm.magic >> ppm.w >> ppm.h >> ppm.max;
    infile.seekg(1, infile.cur);<char*>(ppm.buffer), ppm.w * ppm.h * 3);

inline void readPPMFile(const std::string& filename, vPPM& ppm, std::vector<std::string>& input_dir)
    ppm.fileName = filename;
    std::ifstream infile(locateFile(filename, input_dir), std::ifstream::binary);
    infile >> ppm.magic >> ppm.w >> ppm.h >> ppm.max;
    infile.seekg(1, infile.cur);

    for (int i = 0; i < ppm.w * ppm.h * 3; ++i)
    }<char*>(&ppm.buffer[0]), ppm.w * ppm.h * 3);

template <int C, int H, int W>
void writePPMFileWithBBox(const std::string& filename, PPM<C, H, W>& ppm, const BBox& bbox)
    std::ofstream outfile("./" + filename, std::ofstream::binary);
    outfile << "P6"
            << "\n"
            << ppm.w << " " << ppm.h << "\n"
            << ppm.max << "\n";

    auto round = [](float x) -> int { return int(std::floor(x + 0.5f)); };
    const int x1 = std::min(std::max(0, round(int(bbox.x1))), W - 1);
    const int x2 = std::min(std::max(0, round(int(bbox.x2))), W - 1);
    const int y1 = std::min(std::max(0, round(int(bbox.y1))), H - 1);
    const int y2 = std::min(std::max(0, round(int(bbox.y2))), H - 1);

    for (int x = x1; x <= x2; ++x)
        // bbox top border
        ppm.buffer[(y1 * ppm.w + x) * 3] = 255;
        ppm.buffer[(y1 * ppm.w + x) * 3 + 1] = 0;
        ppm.buffer[(y1 * ppm.w + x) * 3 + 2] = 0;
        // bbox bottom border
        ppm.buffer[(y2 * ppm.w + x) * 3] = 255;
        ppm.buffer[(y2 * ppm.w + x) * 3 + 1] = 0;
        ppm.buffer[(y2 * ppm.w + x) * 3 + 2] = 0;

    for (int y = y1; y <= y2; ++y)
        // bbox left border
        ppm.buffer[(y * ppm.w + x1) * 3] = 255;
        ppm.buffer[(y * ppm.w + x1) * 3 + 1] = 0;
        ppm.buffer[(y * ppm.w + x1) * 3 + 2] = 0;
        // bbox right border
        ppm.buffer[(y * ppm.w + x2) * 3] = 255;
        ppm.buffer[(y * ppm.w + x2) * 3 + 1] = 0;
        ppm.buffer[(y * ppm.w + x2) * 3 + 2] = 0;

    outfile.write(reinterpret_cast<char*>(ppm.buffer), ppm.w * ppm.h * 3);

inline void writePPMFileWithBBox(const std::string& filename, vPPM ppm, std::vector<BBox>& dets)
    std::ofstream outfile("./" + filename, std::ofstream::binary);
    outfile << "P6"
            << "\n"
            << ppm.w << " " << ppm.h << "\n"
            << ppm.max << "\n";
    auto round = [](float x) -> int { return int(std::floor(x + 0.5f)); };

    for (auto bbox : dets)
        for (int x = int(bbox.x1); x < int(bbox.x2); ++x)
            // bbox top border
            ppm.buffer[(round(bbox.y1) * ppm.w + x) * 3] = 255;
            ppm.buffer[(round(bbox.y1) * ppm.w + x) * 3 + 1] = 0;
            ppm.buffer[(round(bbox.y1) * ppm.w + x) * 3 + 2] = 0;
            // bbox bottom border
            ppm.buffer[(round(bbox.y2) * ppm.w + x) * 3] = 255;
            ppm.buffer[(round(bbox.y2) * ppm.w + x) * 3 + 1] = 0;
            ppm.buffer[(round(bbox.y2) * ppm.w + x) * 3 + 2] = 0;

        for (int y = int(bbox.y1); y < int(bbox.y2); ++y)
            // bbox left border
            ppm.buffer[(y * ppm.w + round(bbox.x1)) * 3] = 255;
            ppm.buffer[(y * ppm.w + round(bbox.x1)) * 3 + 1] = 0;
            ppm.buffer[(y * ppm.w + round(bbox.x1)) * 3 + 2] = 0;
            // bbox right border
            ppm.buffer[(y * ppm.w + round(bbox.x2)) * 3] = 255;
            ppm.buffer[(y * ppm.w + round(bbox.x2)) * 3 + 1] = 0;
            ppm.buffer[(y * ppm.w + round(bbox.x2)) * 3 + 2] = 0;

    outfile.write(reinterpret_cast<char*>(&ppm.buffer[0]), ppm.w * ppm.h * 3);

//! Base class for timers that accumulate elapsed time (stored in milliseconds).
//! Derived timers override start()/stop() and add each measured interval to mMs.
class TimerBase
{
public:
    // Timers are used polymorphically (virtual start/stop), so deleting through a base
    // pointer must be well defined.
    virtual ~TimerBase() = default;

    virtual void start() {}
    virtual void stop() {}

    float microseconds() const noexcept
    {
        return mMs * 1000.f;
    }
    float milliseconds() const noexcept
    {
        return mMs;
    }
    float seconds() const noexcept
    {
        return mMs / 1000.f;
    }
    //! Discards all accumulated time.
    void reset() noexcept
    {
        mMs = 0.f;
    }

protected:
    float mMs{0.0f}; // accumulated elapsed time in milliseconds
};

class GpuTimer : public TimerBase
    explicit GpuTimer(cudaStream_t stream)
        : mStream(stream)
    void start()
        CHECK(cudaEventRecord(mStart, mStream));
    void stop()
        CHECK(cudaEventRecord(mStop, mStream));
        float ms{0.0f};
        CHECK(cudaEventElapsedTime(&ms, mStart, mStop));
        mMs += ms;

    cudaEvent_t mStart, mStop;
    cudaStream_t mStream;
}; // class GpuTimer

template <typename Clock>
class CpuTimer : public TimerBase
    using clock_type = Clock;

    void start()
        mStart = Clock::now();
    void stop()
        mStop = Clock::now();
        mMs += std::chrono::duration<float, std::milli>{mStop - mStart}.count();

    std::chrono::time_point<Clock> mStart, mStop;
}; // class CpuTimer

using PreciseCpuTimer = CpuTimer<std::chrono::high_resolution_clock>;

//! Splits `str` at every occurrence of `delimiter`.
//! Empty fields are kept, so the result always has (number of delimiters + 1) elements,
//! e.g. "a,,b" -> {"a", "", "b"} and "" -> {""}.
inline std::vector<std::string> splitString(std::string str, char delimiter = ',')
{
    std::vector<std::string> splitVect;
    std::string::size_type begin = 0;
    while (true)
    {
        const std::string::size_type end = str.find(delimiter, begin);
        if (end == std::string::npos)
        {
            splitVect.emplace_back(str.substr(begin));
            break;
        }
        splitVect.emplace_back(str.substr(begin, end - begin));
        begin = end + 1;
    }
    return splitVect;
}

//! Channel extent: the third-from-last dimension, or 1 if dims has fewer than 3 dimensions.
inline int getC(nvinfer1::Dims const& d)
{
    if (d.nbDims < 3)
    {
        return 1;
    }
    return d.d[d.nbDims - 3];
}

//! Height extent: the second-from-last dimension, or 1 if dims has fewer than 2 dimensions.
inline int getH(const nvinfer1::Dims& d)
{
    if (d.nbDims < 2)
    {
        return 1;
    }
    return d.d[d.nbDims - 2];
}

//! Width extent: the last dimension, or 1 if dims has no dimensions.
inline int getW(const nvinfer1::Dims& d)
{
    if (d.nbDims < 1)
    {
        return 1;
    }
    return d.d[d.nbDims - 1];
}

inline void loadLibrary(const std::string& path)
#ifdef _MSC_VER
    void* handle = LoadLibrary(path.c_str());
    int32_t flags{RTLD_LAZY};
    // asan doesn't handle module unloading correctly and there are no plans on doing
    // so. In order to get proper stack traces, don't delete the shared library on
    // close so that asan can resolve the symbols correctly.
    flags |= RTLD_NODELETE;
#endif // ENABLE_ASAN

    void* handle = dlopen(path.c_str(), flags);
    if (handle == nullptr)
#ifdef _MSC_VER
        sample::gLogError << "Could not load plugin library: " << path << std::endl;
        sample::gLogError << "Could not load plugin library: " << path << ", due to: " << dlerror() << std::endl;

inline int32_t getSMVersion()
    int32_t deviceIndex = 0;

    int32_t major, minor;
    CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, deviceIndex));
    CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, deviceIndex));

    return ((major << 8) | minor);

inline bool isSMSafe()
    const int32_t smVersion = getSMVersion();
    return smVersion == 0x0700 || smVersion == 0x0702 || smVersion == 0x0705 || smVersion == 0x0800
        || smVersion == 0x0806 || smVersion == 0x0807;

//! Returns the maximum persisting-L2-cache size (bytes) of the current CUDA device,
//! or 0 when built against a CUDA runtime older than 11.3 (attribute unavailable).
inline int32_t getMaxPersistentCacheSize()
{
    int32_t maxPersistentL2CacheSize{0};
#if CUDART_VERSION >= 11030
    int32_t deviceIndex{};
    CHECK(cudaGetDevice(&deviceIndex));
    CHECK(cudaDeviceGetAttribute(&maxPersistentL2CacheSize, cudaDevAttrMaxPersistingL2CacheSize, deviceIndex));
#endif
    return maxPersistentL2CacheSize;
}

//! Reports whether the platform has fast support for `dataType`: INT8 and HALF are checked
//! against the builder's platform queries, any other type is assumed supported.
//! Returns false if a builder cannot be created.
inline bool isDataTypeSupported(nvinfer1::DataType dataType)
{
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    switch (dataType)
    {
    case nvinfer1::DataType::kINT8: return builder->platformHasFastInt8();
    case nvinfer1::DataType::kHALF: return builder->platformHasFastFp16();
    default: return true;
    }
}

class FileLock
    FileLock(std::string const& fileName)
        : fileName(fileName)
        std::string lockFileName = fileName + ".lock";
#ifdef _MSC_VER
        sample::gLogVerbose << "Trying to set exclusive file lock " << lockFileName << std::endl;
        auto startTime = std::chrono::high_resolution_clock::now();
        // MS docs said this is a blocking IO if "FILE_FLAG_OVERLAPPED" is not provided
        lock = CreateFileA(lockFileName.c_str(), GENERIC_WRITE, 0, NULL, OPEN_ALWAYS, 0, NULL);
        if (lock != INVALID_HANDLE_VALUE)
            float const time
                = std::chrono::duration<float>(std::chrono::high_resolution_clock::now() - startTime).count();
            sample::gLogVerbose << "File locked in " << time << " seconds." << std::endl;
            throw std::runtime_error("Failed to lock " + lockFileName + "!");
#elif defined(__QNX__)
        // We once enabled the file lock on QNX, lockf(F_TLOCK) return -1 and the reported error is
        // The error generated was 89
        // That means : Function not implemented
        fp = fopen(lockFileName.c_str(), "wb+");
        if (!fp)
            throw std::runtime_error("Cannot open " + lockFileName + "!");
        fd = fileno(fp);
        sample::gLogVerbose << "Trying to set exclusive file lock " << lockFileName << std::endl;
        auto startTime = std::chrono::high_resolution_clock::now();
        auto ret = lockf(fd, F_LOCK, 0);
        if (ret != 0)
            fd = -1;
            throw std::runtime_error("Failed to lock " + lockFileName + "!");
        float const time = std::chrono::duration<float>(std::chrono::high_resolution_clock::now() - startTime).count();
        sample::gLogVerbose << "File locked in " << time << " seconds." << std::endl;

        std::string lockFileName = fileName + ".lock";
#ifdef _MSC_VER
        if (lock != INVALID_HANDLE_VALUE)
            sample::gLogVerbose << "Trying to remove exclusive file lock " << lockFileName << std::endl;
            auto startTime = std::chrono::high_resolution_clock::now();
            float const time
                = std::chrono::duration<float>(std::chrono::high_resolution_clock::now() - startTime).count();
            sample::gLogVerbose << "File unlocked in " << time << " seconds." << std::endl;
#elif defined(__QNX__)
        // We once enabled the file lock on QNX, lockf(F_TLOCK) return -1 and the reported error is
        // The error generated was 89
        // That means : Function not implemented
        if (fd != -1)
            sample::gLogVerbose << "Trying to remove exclusive file lock " << lockFileName << std::endl;
            auto startTime = std::chrono::high_resolution_clock::now();
            auto ret = lockf(fd, F_ULOCK, 0);
            if (ret != 0)
                sample::gLogVerbose << "Failed to unlock " << lockFileName << "!" << std::endl;
                fd = -1;
                float const time
                    = std::chrono::duration<float>(std::chrono::high_resolution_clock::now() - startTime).count();
                sample::gLogVerbose << "File unlocked in " << time << " seconds." << std::endl;

    FileLock() = delete;                           // no default ctor
    FileLock(FileLock const&) = delete;            // no copy ctor
    FileLock& operator=(FileLock const&) = delete; // no copy assignment

    const std::string fileName; // the file being protected
#ifdef _MSC_VER
    HANDLE lock;
    FILE* fp;
    int32_t fd;

inline std::vector<char> loadTimingCacheFile(const std::string inFileName)
    std::unique_ptr<samplesCommon::FileLock> fileLock{new samplesCommon::FileLock(inFileName)};
    std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);
    if (!iFile)
        sample::gLogWarning << "Could not read timing cache from: " << inFileName
                            << ". A new timing cache will be generated and written." << std::endl;
        return std::vector<char>();
    iFile.seekg(0, std::ifstream::end);
    size_t fsize = iFile.tellg();
    iFile.seekg(0, std::ifstream::beg);
    std::vector<char> content(fsize);, fsize);
    sample::gLogInfo << "Loaded " << fsize << " bytes of timing cache from " << inFileName << std::endl;
    return content;

//! Writes the serialized timing cache `blob` to `outFileName` while holding its FileLock.
//! If the file cannot be opened a warning is logged and nothing is written.
inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob)
{
    samplesCommon::FileLock fileLock{outFileName};
    std::ofstream oFile(outFileName, std::ios::out | std::ios::binary);
    if (!oFile)
    {
        sample::gLogWarning << "Could not write timing cache to: " << outFileName << std::endl;
        return;
    }
    oFile.write(static_cast<const char*>(blob->data()), static_cast<std::streamsize>(blob->size()));
    oFile.close();
    sample::gLogInfo << "Saved " << blob->size() << " bytes of timing cache to " << outFileName << std::endl;
}

inline void updateTimingCacheFile(std::string const fileName, nvinfer1::ITimingCache const* timingCache)
    // Prepare empty timingCache in case that there is no existing file to read
    std::unique_ptr<nvinfer1::IBuilder> builder{nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger())};
    std::unique_ptr<nvinfer1::IBuilderConfig> config{builder->createBuilderConfig()};
    std::unique_ptr<nvinfer1::ITimingCache> fileTimingCache{config->createTimingCache(static_cast<const void*>(nullptr), 0)};

    std::unique_ptr<samplesCommon::FileLock> fileLock{new samplesCommon::FileLock(fileName)};
    std::ifstream iFile(fileName, std::ios::in | std::ios::binary);
    if (iFile)
        iFile.seekg(0, std::ifstream::end);
        size_t fsize = iFile.tellg();
        iFile.seekg(0, std::ifstream::beg);
        std::vector<char> content(fsize);, fsize);
        sample::gLogInfo << "Loaded " << fsize << " bytes of timing cache from " << fileName << std::endl;
        fileTimingCache.reset(config->createTimingCache(static_cast<const void*>(, content.size()));
        if (!fileTimingCache)
            throw std::runtime_error("Failed to create timingCache from " + fileName + "!");
    fileTimingCache->combine(*timingCache, false);
    std::unique_ptr<nvinfer1::IHostMemory> blob{fileTimingCache->serialize()};
    if (!blob)
        throw std::runtime_error("Failed to serialize ITimingCache!");
    std::ofstream oFile(fileName, std::ios::out | std::ios::binary);
    if (!oFile)
        sample::gLogWarning << "Could not write timing cache to: " << fileName << std::endl;
    oFile.write((char*) blob->data(), blob->size());
    sample::gLogInfo << "Saved " << blob->size() << " bytes of timing cache to " << fileName << std::endl;

} // namespace samplesCommon

//! Prints dims as a parenthesized, comma-separated list, e.g. "(1, 3, 32, 320)".
inline std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& dims)
{
    os << "(";
    const char* separator = "";
    for (int i = 0; i < dims.nbDims; ++i)
    {
        os << separator << dims.d[i];
        separator = ", ";
    }
    os << ")";
    return os;
}



#include "engine.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "NvOnnxParser.h"

using namespace nvinfer1;
using namespace nvonnxparser;

// Normalizes *im in place: converts it to 32-bit float (scaled by 1/255 when is_scale is
// true), then maps every channel c to (value - mean[c]) * scale[c].
// Throws std::invalid_argument if im is null or the image has more channels than
// mean/scale provide values for. An empty image is left unchanged.
void Normalize(cv::Mat *im,
               const std::vector<float> &mean = {0.5f, 0.5f, 0.5f},
               const std::vector<float> &scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f},
               const bool is_scale = true)
{
    if (im == nullptr)
    {
        throw std::invalid_argument("Normalize: image pointer is null");
    }
    if (im->empty())
    {
        return;
    }

    const double e = is_scale ? 1.0 / 255.0 : 1.0;
    // convertTo only changes the depth; the channel count of the input is kept.
    im->convertTo(*im, CV_32FC3, e);

    std::vector<cv::Mat> channels;
    cv::split(*im, channels);
    // cv::split produces one Mat per image channel (e.g. 4 for an image with alpha).
    if (channels.size() > mean.size() || channels.size() > scale.size())
    {
        throw std::invalid_argument("Normalize: mean and scale need one value per image channel");
    }
    for (size_t i = 0; i < channels.size(); ++i)
    {
        // (x - mean) * scale == x * scale + (-mean * scale)
        channels[i].convertTo(channels[i], CV_32FC1, static_cast<double>(scale[i]),
                              (0.0 - mean[i]) * scale[i]);
    }
    cv::merge(channels, *im);
}

// TensorRT logging callback: prints messages of severity kWARNING or more severe to stdout
// and drops everything less severe. A dedicated logging library could replace this.
void Logger::log(Severity severity, const char *msg) noexcept
{
    if (severity > Severity::kWARNING)
    {
        return;
    }
    std::cout << msg << std::endl;
}

// True if `filepath` can be opened for reading.
bool Engine::doesFileExist(const std::string &filepath)
{
    return std::ifstream(filepath.c_str()).good();
}

// Stores the options; nothing is built or loaded until build()/loadNetwork() are called.
Engine::Engine(const Options &options)
    : m_options(options)
{
}

bool Engine::build(std::string onnxModelPath)
    // Only regenerate the engine file if it has not already been generated for the specified options
    m_engineName = serializeEngineOptions(m_options);
    std::cout << "Searching for engine file with name: " << m_engineName << std::endl;

    if (doesFileExist(m_engineName))
        std::cout << "Engine found, not regenerating..." << std::endl;
        return true;

    // Was not able to find the engine file, generate...
    std::cout << "Engine not found, generating..." << std::endl;

    // Create our engine builder.
    auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(m_logger));
    if (!builder)
        return false;

    // Set the max supported batch size

    // Define an explicit batch size and then create the network.
    // More info here:
    auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
    if (!network)
        return false;

    // Create a parser for reading the onnx file.
    auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, m_logger));
    if (!parser)
        return false;

    // We are going to first read the onnx file into memory, then pass that buffer to the parser.
    // Had our onnx model file been encrypted, this approach would allow us to first decrypt the buffer.

    std::ifstream file(onnxModelPath, std::ios::binary | std::ios::ate);
    std::streamsize size = file.tellg();
    file.seekg(0, std::ios::beg);

    std::vector<char> buffer(size);
    if (!, size))
        throw std::runtime_error("Unable to read engine file");

    auto parsed = parser->parse(, buffer.size());
    if (!parsed)
        return false;

    // Save the input height, width, and channels.
    // Require this info for inference.
    const auto input = network->getInput(0);
    const auto inputName = input->getName();

    auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
        return false;

    // Specify the optimization profiles and the
    IOptimizationProfile *defaultProfile = builder->createOptimizationProfile();
    defaultProfile->setDimensions(inputName, OptProfileSelector::kMIN, Dims4(1, m_options.inputDimension[0], m_options.inputDimension[1], m_options.inputDimension[2]));
    defaultProfile->setDimensions(inputName, OptProfileSelector::kOPT, Dims4(1, m_options.inputDimension[0], m_options.inputDimension[1], m_options.inputDimension[2]));
    defaultProfile->setDimensions(inputName, OptProfileSelector::kMAX, Dims4(1, m_options.inputDimension[0], m_options.inputDimension[1], m_options.inputDimension[2]));



    // Disable: cublas for PPOCRv3. You can comment this line if you are using PPOCRv2.
    // Reference:

    if (m_options.FP16)

    // CUDA stream used for profiling by the builder.
    auto profileStream = samplesCommon::makeCudaStream();
    if (!profileStream)
        return false;

    // Build the engine
    std::unique_ptr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
    if (!plan)
        return false;

    // Write the engine to disk
    std::ofstream outfile(m_engineName, std::ofstream::binary);
    outfile.write(reinterpret_cast<const char *>(plan->data()), plan->size());

    std::cout << "Success, saved engine to " << m_engineName << std::endl;

    return true;

    // NOTE(review): orphaned statement. This looks like the start of Engine's destructor
    // (releasing m_cudaStream, which loadNetwork creates with cudaStreamCreate); its
    // signature and the rest of its body are not visible in this copy of the file.
    // Confirm it calls cudaStreamDestroy(m_cudaStream).
    if (m_cudaStream)

bool Engine::loadNetwork()
    // Read the serialized model from disk
    std::ifstream file(m_engineName, std::ios::binary | std::ios::ate);
    std::streamsize size = file.tellg();
    file.seekg(0, std::ios::beg);

    std::vector<char> buffer(size);
    if (!, size))
        throw std::runtime_error("Unable to read engine file");

    std::unique_ptr<IRuntime> runtime{createInferRuntime(m_logger)};
    if (!runtime)
        return false;

    // Set the device index
    auto ret = cudaSetDevice(m_options.deviceIndex);
    if (ret != 0)
        int numGPUs;
        auto errMsg = "Unable to set GPU device index to: " + std::to_string(m_options.deviceIndex) +
                      ". Note, your device has " + std::to_string(numGPUs) + " CUDA-capable GPU(s).";
        throw std::runtime_error(errMsg);

    m_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(, buffer.size()));
    if (!m_engine)
        return false;

    m_context = std::unique_ptr<nvinfer1::IExecutionContext>(m_engine->createExecutionContext());
    if (!m_context)
        return false;

    auto cudaRet = cudaStreamCreate(&m_cudaStream);
    if (cudaRet != 0)
        throw std::runtime_error("Unable to create cuda stream");

    return true;

// Runs the network on a batch of preprocessed images.
// Every image must be a CV_32FC3 cv::Mat (e.g. the output of Normalize) of size
// inputDimension[2] x inputDimension[1]. For each image, one vector with that sample's
// output elements is appended to featureVectors.
// Returns the number of output elements per sample, or 0 on failure.
int Engine::runInference(const std::vector<cv::Mat> &inputFaceChips, std::vector<std::vector<float>> &featureVectors)
{
    if (!m_context)
    {
        std::cout << "Network is not loaded; call loadNetwork() first" << std::endl;
        return 0;
    }
    if (inputFaceChips.empty() || m_options.inputDimension.size() != 3 || m_options.inputDimension[0] != 3)
    {
        std::cout << "Invalid input: expected at least one image and a 3-channel inputDimension" << std::endl;
        return 0;
    }

    const int32_t batchSize = static_cast<int32_t>(inputFaceChips.size());
    const int inputH = m_options.inputDimension[1];
    const int inputW = m_options.inputDimension[2];
    for (const auto &chip : inputFaceChips)
    {
        // The device/host buffers are sized from the options, so every image must match exactly.
        if (chip.type() != CV_32FC3 || chip.rows != inputH || chip.cols != inputW)
        {
            std::cout << "Each input image must be CV_32FC3 of size " << inputW << "x" << inputH << std::endl;
            return 0;
        }
    }

    // (Re)configure bindings and buffers on first use and whenever the batch size changes;
    // otherwise the buffers allocated for a previous batch size would be overrun.
    if (!if_initialize || m_context->getBindingDimensions(0).d[0] != batchSize)
    {
        const Dims4 inputDims{batchSize, m_options.inputDimension[0], inputH, inputW};
        if (!m_context->setBindingDimensions(0, inputDims))
        {
            throw std::runtime_error("Error, unable to set input dimensions (batch size outside the optimization profile?).");
        }
        if (!m_context->allInputDimensionsSpecified())
        {
            throw std::runtime_error("Error, not all input dimensions specified.");
        }
        // Output dimensions are only resolved once the input dimensions are set.
        outputDims = m_context->getBindingDimensions(1);

        m_inputBuff.hostBuffer.resize(inputDims);
        m_inputBuff.deviceBuffer.resize(inputDims);
        m_outputBuff.hostBuffer.resize(outputDims);
        m_outputBuff.deviceBuffer.resize(outputDims);

        if_initialize = true;
    }

    // Interleaved (HWC) -> planar (CHW) conversion. The channel order is kept as stored in
    // the cv::Mat (BGR for images loaded with OpenCV).
    auto *hostDataBuffer = static_cast<float *>(m_inputBuff.hostBuffer.data());
    const size_t area = static_cast<size_t>(inputH) * static_cast<size_t>(inputW);
    for (size_t batch = 0; batch < inputFaceChips.size(); ++batch)
    {
        const cv::Mat &chip = inputFaceChips[batch];
        // Pixel data is read linearly, so make sure it is contiguous.
        const cv::Mat image = chip.isContinuous() ? chip : chip.clone();
        const float *src = image.ptr<float>();
        float *dst = hostDataBuffer + 3 * area * batch;
        for (size_t p = 0; p < area; ++p)
        {
            dst[p] = src[3 * p];
            dst[area + p] = src[3 * p + 1];
            dst[2 * area + p] = src[3 * p + 2];
        }
    }

    // Copy from CPU to GPU
    auto ret = cudaMemcpyAsync(m_inputBuff.deviceBuffer.data(), m_inputBuff.hostBuffer.data(),
                               m_inputBuff.hostBuffer.nbBytes(), cudaMemcpyHostToDevice, m_cudaStream);
    if (ret != cudaSuccess)
    {
        std::cout << "Unable to copy buffer from CPU to GPU" << std::endl;
        return 0;
    }

    // Binding 0 is the input, binding 1 the output.
    std::vector<void *> predicitonBindings = {m_inputBuff.deviceBuffer.data(), m_outputBuff.deviceBuffer.data()};

    // Run inference
    bool status = m_context->enqueueV2(predicitonBindings.data(), m_cudaStream, nullptr);
    if (!status)
    {
        return 0;
    }

    // Copy the results back to CPU memory
    ret = cudaMemcpyAsync(m_outputBuff.hostBuffer.data(), m_outputBuff.deviceBuffer.data(),
                          m_outputBuff.deviceBuffer.nbBytes(), cudaMemcpyDeviceToHost, m_cudaStream);
    if (ret != cudaSuccess)
    {
        std::cout << "Unable to copy buffer from GPU back to CPU" << std::endl;
        return 0;
    }

    ret = cudaStreamSynchronize(m_cudaStream);
    if (ret != cudaSuccess)
    {
        std::cout << "Unable to synchronize cuda stream" << std::endl;
        return 0;
    }

    // Elements per sample = product of all output dimensions after the batch dimension
    // (d[1] * d[2] for a 3-D output).
    size_t outsize = 1;
    for (int32_t i = 1; i < outputDims.nbDims; ++i)
    {
        outsize *= static_cast<size_t>(outputDims.d[i]);
    }

    const auto *outData = static_cast<const float *>(m_outputBuff.hostBuffer.data());
    for (int32_t batch = 0; batch < batchSize; ++batch)
    {
        const float *first = outData + outsize * static_cast<size_t>(batch);
        featureVectors.emplace_back(first, first + outsize);
    }

    return static_cast<int>(outsize);
}

std::string Engine::serializeEngineOptions(const Options &options)
    std::string engineName = "trt.engine";

    std::vector<std::string> gpuUUIDs;

    if (static_cast<size_t>(options.deviceIndex) >= gpuUUIDs.size())
        throw std::runtime_error("Error, provided device index is out of range!");

    engineName += "." + gpuUUIDs[options.deviceIndex];

    // Serialize the specified options into the filename
    if (options.FP16)
        engineName += ".fp16";
        engineName += ".fp32";

    engineName += "." + std::to_string(options.maxBatchSize) + ".";
    for (size_t i = 0; i < m_options.optBatchSizes.size(); ++i)
        engineName += std::to_string(m_options.optBatchSizes[i]);
        if (i != m_options.optBatchSizes.size() - 1)
            engineName += "_";

    engineName += "." + std::to_string(options.maxWorkspaceSize);

    return engineName;

void Engine::getGPUUUIDs(std::vector<std::string> &gpuUUIDs)
    int numGPUs;

    for (int device = 0; device < numGPUs; device++)
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device);

        char uuid[33];
        for (int b = 0; b < 16; b++)
            sprintf(&uuid[b * 2], "%02x", (unsigned char)prop.uuid.bytes[b]);

        // by comparing uuid against a preset list of valid uuids given by the client (using: nvidia-smi -L) we decide which gpus can be used.

// Loads inputImage as a 3-channel, 8-bit BGR image and resizes it to height
// inputDimension[1] while keeping the aspect ratio. Images that would be wider than
// inputDimension[2] are squeezed to that width; narrower ones are right-padded with black.
// Throws std::runtime_error if the image cannot be read.
cv::Mat Engine::preprocessImg(const std::string inputImage)
{
    // IMREAD_COLOR guarantees 3 channels: the network input is 3-channel, and grayscale or
    // alpha images would otherwise break later stages.
    cv::Mat img = cv::imread(inputImage, cv::IMREAD_COLOR);
    if (img.empty())
    {
        throw std::runtime_error("Unable to read image: " + inputImage);
    }

    const int targetH = m_options.inputDimension[1];
    const int targetW = m_options.inputDimension[2];
    const double scaledW = img.cols * 1.0 / img.rows * targetH;

    cv::Mat resize_img;
    if (scaledW >= targetW)
    {
        cv::resize(img, resize_img, cv::Size(targetW, targetH));
    }
    else
    {
        // scaledW < targetW, so resizedW <= targetW and the padding below is non-negative.
        const int resizedW = static_cast<int>(scaledW + 1);
        cv::resize(img, resize_img, cv::Size(resizedW, targetH), 0.f, 0.f, cv::INTER_LINEAR);
        // Padding the right part
        cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, targetW - resizedW, cv::BORDER_CONSTANT,
                           cv::Scalar(0, 0, 0));
    }

    return resize_img;
}

// Reads a label dictionary: one entry per line, in file order. Empty lines are kept as empty
// entries; a trailing '\r' (Windows line ending) is stripped.
// If the file cannot be opened, prints a message and terminates the program with exit code 1.
std::vector<std::string> ReadDict(const std::string &path)
{
    std::ifstream in(path);
    std::vector<std::string> m_vec;
    if (!in)
    {
        std::cout << "no such label file: " << path << ", exit the program..."
                  << std::endl;
        std::exit(1);
    }

    std::string line;
    while (std::getline(in, line))
    {
        if (!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        m_vec.push_back(line);
    }
    return m_vec;
}


#pragma once

#include <opencv2/opencv.hpp>

#include "NvInfer.h"
#include "buffers.h"

// Options that control how the TensorRT engine is built and run.
struct Options
{
    // Build the engine with 16-bit floating point (FP16) enabled.
    bool FP16 = false;
    // Batch sizes the engine should be optimized for.
    std::vector<int32_t> optBatchSizes = {1};
    // Largest batch size the engine must accept.
    int32_t maxBatchSize = 1;
    // Upper bound, in bytes, on GPU workspace memory TensorRT may use while building.
    // Give the builder as much as can be afforded: at runtime no more than this is
    // allocated, and typically less.
    size_t maxWorkspaceSize = 40000000000;
    // Index of the CUDA device to run on.
    int deviceIndex = 0;
    // Network input dimensions in C, H, W order.
    std::vector<int> inputDimension = {3, 32, 320};
};

// Class to extend TensorRT logger
class Logger : public nvinfer1::ILogger
    void log(Severity severity, const char *msg) noexcept override;

class Engine
    Engine(const Options &options);
    // Build the network
    bool build(std::string onnxModelPath);
    // Load and prepare the network for inference
    bool loadNetwork();
    // Preprocess
    cv::Mat preprocessImg(const std::string);
    // Run inference.
    int runInference(const std::vector<cv::Mat> &inputFaceChips, std::vector<std::vector<float>> &featureVectors);

    nvinfer1::Dims outputDims;

    // Converts the engine options into a string
    std::string serializeEngineOptions(const Options &options);

    void getGPUUUIDs(std::vector<std::string> &gpuUUIDs);

    bool doesFileExist(const std::string &filepath);

    std::unique_ptr<nvinfer1::ICudaEngine> m_engine = nullptr;
    std::unique_ptr<nvinfer1::IExecutionContext> m_context = nullptr;
    const Options &m_options;
    Logger m_logger;
    samplesCommon::ManagedBuffer m_inputBuff;
    samplesCommon::ManagedBuffer m_outputBuff;
    bool if_initialize = 0;
    std::string m_engineName;
    cudaStream_t m_cudaStream = nullptr;

// Normalizes *im in place: converts it to 32-bit float (scaled by 1/255 when is_scale is
// true), then maps each channel c to (value - mean[c]) * scale[c].
// The default arguments exist only on the definition in engine.cpp; code that sees only this
// declaration must pass all four arguments.
void Normalize(cv::Mat *im, const std::vector<float> &mean,
               const std::vector<float> &scale, const bool is_scale);

// Reads a label dictionary file, one entry per line.
std::vector<std::string> ReadDict(const std::string &path);


#include "NvInferRuntimeCommon.h"
#include "logger.h"
#include <atomic>
#include <cstdint>
#include <exception>
#include <mutex>
#include <vector>

using nvinfer1::IErrorRecorder;
using nvinfer1::ErrorCode;

//! A simple implementation of the IErrorRecorder interface for
//! use by samples. This interface also can be used as a reference
//! implementation.
//! The sample Error recorder is based on a vector that pairs the error
//! code and the error string into a single element. It also uses
//! standard mutex's and atomics in order to make sure that the code
//! works in a multi-threaded environment.
class SampleErrorRecorder : public IErrorRecorder
    using errorPair = std::pair<ErrorCode, std::string>;
    using errorStack = std::vector<errorPair>;

    SampleErrorRecorder() = default;

    virtual ~SampleErrorRecorder() noexcept {}
    int32_t getNbErrors() const noexcept final
        return mErrorStack.size();
    ErrorCode getErrorCode(int32_t errorIdx) const noexcept final
        return invalidIndexCheck(errorIdx) ? ErrorCode::kINVALID_ARGUMENT : (*this)[errorIdx].first;
    IErrorRecorder::ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept final
        return invalidIndexCheck(errorIdx) ? "errorIdx out of range." : (*this)[errorIdx].second.c_str();
    // This class can never overflow since we have dynamic resize via std::vector usage.
    bool hasOverflowed() const noexcept final
        return false;

    // Empty the errorStack.
    void clear() noexcept final
            // grab a lock so that there is no addition while clearing.
            std::lock_guard<std::mutex> guard(mStackLock);
        catch (const std::exception& e)
            sample::gLogFatal << "Internal Error: " << e.what() << std::endl;

    //! Simple helper function that
    bool empty() const noexcept
        return mErrorStack.empty();

    bool reportError(ErrorCode val, IErrorRecorder::ErrorDesc desc) noexcept final
            std::lock_guard<std::mutex> guard(mStackLock);
            sample::gLogError << "Error[" << static_cast<int32_t>(val) << "]: " << desc << std::endl;
            mErrorStack.push_back(errorPair(val, desc));
        catch (const std::exception& e)
            sample::gLogFatal << "Internal Error: " << e.what() << std::endl;
        // All errors are considered fatal.
        return true;

    // Atomically increment or decrement the ref counter.
    IErrorRecorder::RefCount incRefCount() noexcept final
        return ++mRefCount;
    IErrorRecorder::RefCount decRefCount() noexcept final
        return --mRefCount;

    // Simple helper functions.
    const errorPair& operator[](size_t index) const noexcept
        return mErrorStack[index];

    bool invalidIndexCheck(int32_t index) const noexcept
        // By converting signed to unsigned, we only need a single check since
        // negative numbers turn into large positive greater than the size.
        size_t sIndex = index;
        return sIndex >= mErrorStack.size();
    // Mutex to hold when locking mErrorStack.
    std::mutex mStackLock;

    // Reference count of the class. Destruction of the class when mRefCount
    // is not zero causes undefined behavior.
    std::atomic<int32_t> mRefCount{0};

    // The error stack that holds the errors recorded by TensorRT.
    errorStack mErrorStack;
};     // class SampleErrorRecorder


// half - IEEE 754-based half-precision floating point library.
// Copyright (c) 2012-2017 Christian Rau <>
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
// Software.

// Version 1.12.0

/// \file
/// Main header file for half precision functionality.


/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)

// check C++11 language features
// Each `#if` below tests one compiler feature and is meant to guard a
// `#define HALF_ENABLE_CPP11_<FEATURE> 1` followed by `#endif` (the pattern is still visible in
// the commented-out Intel section).
// NOTE(review): in this copy those #define/#endif lines appear to be missing, leaving the
// conditionals unbalanced; restore them from upstream half.hpp 1.12.0.
#if defined(__clang__) // clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
/*#elif defined(__INTEL_COMPILER)								//Intel C++
    #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)		????????
    #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)			????????
        #define HALF_ENABLE_CPP11_CONSTEXPR 1
    #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)			????????
        #define HALF_ENABLE_CPP11_NOEXCEPT 1
    #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG)			????????
        #define HALF_ENABLE_CPP11_LONG_LONG 1
#elif defined(__GNUC__) // gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#elif defined(_MSC_VER) // Visual C++
#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#pragma warning(push)
#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned

// check C++11 library features
// <utility> is included first because the library-identification macros tested below (_LIBCPP_VERSION,
// __GLIBCXX__, _CPPLIB_VER) are provided by the standard headers themselves.
#include <utility>
#if defined(_LIBCPP_VERSION) // libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) // libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
// The gcc version helper is only needed by the feature checks; undefine it so it does not leak to includers.
#undef HALF_GNUC_VERSION

// support constexpr
// This header already requires C++11 unconditionally (<type_traits>, <cstdint>, static_assert), so the C++11
// keywords are used directly.
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr

// support noexcept
// HALF_NOTHROW is defined exactly once: a second definition with a different replacement list (e.g. `throw()`) is an
// ill-formed macro redefinition.
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <type_traits>
#include <cstdint>
#include <functional>

/// Default rounding mode.
/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as
/// well as for the half_cast() if not specifying a rounding mode explicitly. It can be redefined (before including
/// this header) to one of the standard rounding modes using their respective constants or the equivalent values of
/// `std::float_round_style`:
/// `std::float_round_style`         | value | rounding
/// ---------------------------------|-------|-------------------------
/// `std::round_indeterminate`       | -1    | fastest
/// `std::round_toward_zero`         | 0     | toward zero
/// `std::round_to_nearest`          | 1     | to nearest (default here)
/// `std::round_toward_infinity`     | 2     | toward positive infinity
/// `std::round_toward_neg_infinity` | 3     | toward negative infinity
/// In this header the default is `1` (`std::round_to_nearest`). Setting it to `-1` (`std::round_indeterminate`)
/// selects truncation (round toward zero, but with overflows set to infinity), the fastest rounding mode possible. It
/// can also be set to `std::numeric_limits<float>::round_style` to synchronize the rounding mode with that of the
/// underlying single-precision implementation.
#ifndef HALF_ROUND_STYLE
#define HALF_ROUND_STYLE 1 // = std::round_to_nearest
#endif

/// Tie-breaking behaviour for round to nearest.
/// This specifies if ties in round to nearest should be resolved by rounding to the nearest even value. By default this
/// is defined to `0` resulting in the faster but slightly more biased behaviour of rounding away from zero in half-way
/// cases (and thus equal to the round() function), but can be redefined to `1` (before including this header) if more
/// IEEE-conformant behaviour is needed.
#ifndef HALF_ROUND_TIES_TO_EVEN
#define HALF_ROUND_TIES_TO_EVEN 0 // ties away from zero
#endif

/// Value signaling overflow.
/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands to a positive value signaling the overflow
/// of an operation, in particular it just evaluates to positive infinity.
#define HUGE_VALH std::numeric_limits<half_float::half>::infinity()

/// Fast half-precision fma function.
/// This symbol is only defined if the fma() function generally executes as fast as, or faster than, a separate
/// half-precision multiplication followed by an addition. Due to the internal single-precision implementation of all
/// arithmetic operations, this is in fact always the case.
#define FP_FAST_FMAH 1

// Fallback definitions for the <cmath> ilogb() result and floating-point classification macros. Each one is only
// defined when the standard library has not already provided it.
#ifndef FP_ILOGB0
#define FP_ILOGB0 INT_MIN
#endif
#ifndef FP_ILOGBNAN
#define FP_ILOGBNAN INT_MAX
#endif
#ifndef FP_SUBNORMAL
#define FP_SUBNORMAL 0
#endif
#ifndef FP_ZERO
#define FP_ZERO 1
#endif
#ifndef FP_NAN
#define FP_NAN 2
#endif
#ifndef FP_INFINITE
#define FP_INFINITE 3
#endif
#ifndef FP_NORMAL
#define FP_NORMAL 4
#endif

/// Main namespace for half precision functionality.
/// This namespace contains all the functionality provided by the library.
namespace half_float
class half;

/// Library-defined half-precision literals.
/// Import this namespace to enable half-precision floating point literals:
/// ~~~~{.cpp}
/// using namespace half_float::literal;
/// half_float::half = 4.2_h;
/// ~~~~
namespace literal
half operator"" _h(long double);

/// \internal
/// \brief Implementation details.
namespace detail
// Only the standard-library based trait definitions are kept: this header includes <type_traits> and <cstdint>
// unconditionally, and defining the hand-written pre-C++11 fallbacks alongside them would redefine the same names.

/// Conditional type.
template <bool B, typename T, typename F>
struct conditional : std::conditional<B, T, F>
{
};

/// Helper for tag dispatching.
/// `bool_type<true>` / `bool_type<false>` derive from `true_type` / `false_type`, so they convert to those tag types.
template <bool B>
struct bool_type : std::integral_constant<bool, B>
{
};
using std::false_type;
using std::true_type;

/// Type traits for floating point types.
template <typename T>
struct is_float : std::is_floating_point<T>
{
};

/// Type traits for floating point bits.
/// The primary template is a placeholder; `float` and `double` are specialized below with unsigned integer types
/// wide enough to hold their bit patterns. cv-qualified types map to their unqualified counterpart.
template <typename T>
struct bits
{
    typedef unsigned char type;
};
template <typename T>
struct bits<const T> : bits<T>
{
};
template <typename T>
struct bits<volatile T> : bits<T>
{
};
template <typename T>
struct bits<const volatile T> : bits<T>
{
};

/// Unsigned integer of (at least) 16 bits width.
typedef std::uint_least16_t uint16;

/// Unsigned integer of (at least) 32 bits width.
template <>
struct bits<float>
{
    typedef std::uint_least32_t type;
};

/// Unsigned integer of (at least) 64 bits width.
template <>
struct bits<double>
{
    typedef std::uint_least64_t type;
};

/// Tag type for binary construction.
struct binary_t

/// Tag for binary construction.
HALF_CONSTEXPR_CONST binary_t binary = binary_t();

/// Temporary half-precision expression.
/// This class represents a half-precision expression which just stores a single-precision value internally.
struct expr
{
    /// Conversion constructor.
    /// \param f single-precision value to convert
    explicit HALF_CONSTEXPR expr(float f) HALF_NOEXCEPT : value_(f) {}

    /// Conversion to single-precision.
    /// \return single precision value representing expression value
    HALF_CONSTEXPR operator float() const HALF_NOEXCEPT
    {
        return value_;
    }

    /// Internal expression value stored in single-precision.
    float value_;
};

/// SFINAE helper for generic half-precision functions.
/// This class template has to be specialized for each valid combination of argument types to provide a corresponding
/// `type` member equivalent to \a T.
/// \tparam T type to return
template <typename T, typename, typename = void, typename = void>
struct enable
template <typename T>
struct enable<T, half, void, void>
    typedef T type;
template <typename T>
struct enable<T, expr, void, void>
    typedef T type;
template <typename T>
struct enable<T, half, half, void>
    typedef T type;
template <typename T>
struct enable<T, half, expr, void>
    typedef T type;
template <typename T>
struct enable<T, expr, half, void>
    typedef T type;
template <typename T>
struct enable<T, expr, expr, void>
    typedef T type;
template <typename T>
struct enable<T, half, half, half>
    typedef T type;
template <typename T>
struct enable<T, half, half, expr>
    typedef T type;
template <typename T>
struct enable<T, half, expr, half>
    typedef T type;
template <typename T>
struct enable<T, half, expr, expr>
    typedef T type;
template <typename T>
struct enable<T, expr, half, half>
    typedef T type;
template <typename T>
struct enable<T, expr, half, expr>
    typedef T type;
template <typename T>
struct enable<T, expr, expr, half>
    typedef T type;
template <typename T>
struct enable<T, expr, expr, expr>
    typedef T type;

/// Return type for specialized generic 2-argument half-precision functions.
/// This class template has to be specialized for each valid combination of argument types to provide a corresponding
/// `type` member denoting the appropriate return type.
/// \tparam T first argument type
/// \tparam U first argument type
template <typename T, typename U>
struct result : enable<expr, T, U>
template <>
struct result<half, half>
    typedef half type;

/// \name Classification helpers
/// \{
// <cmath> is included unconditionally, so its C++11 classification functions are used directly for all builtin
// floating point types instead of compiler-specific or hand-written fallbacks.

/// Check for infinity.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if infinity
/// \retval false else
template <typename T>
bool builtin_isinf(T arg)
{
    return std::isinf(arg);
}

/// Check for NaN.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if not a number
/// \retval false else
template <typename T>
bool builtin_isnan(T arg)
{
    return std::isnan(arg);
}

/// Check sign.
/// \tparam T argument type (builtin floating point type)
/// \param arg value to query
/// \retval true if signbit set (including negative zero and negative NaN)
/// \retval false else
template <typename T>
bool builtin_signbit(T arg)
{
    return std::signbit(arg);
}

/// \}
/// \name Conversion
/// \{

/// Convert IEEE single-precision to half-precision.
/// Credit for this goes to Jeroen van der Zijp (the reference link is missing from this copy).
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \param value single-precision value
/// \return binary representation of half-precision value
// NOTE(review): this definition is damaged in this copy and does not compile as written:
//  - the function body has no enclosing braces;
//  - the block comment that opens the disabled bit-manipulation draft below is not terminated anywhere in this
//    section, so everything after it is lexically a comment;
//  - no live conversion code and no return statement follow the draft;
//  - the trailing statements index mantissa_table/offset_table/exponent_table with `value >> 10` (not valid for a
//    `float` parameter) and redeclare `bits`, so they appear to be the tail of a separate half-to-float conversion.
// None of those tables, nor the detail::float2half<R>() template that class half calls, is defined in this chunk.
// TODO: restore this section from the upstream half library instead of patching it piecemeal.
template <std::float_round_style R>
uint16 float2half_impl(float value, true_type)
    typedef bits<float>::type uint32;
    // The float's bit pattern is copied with memcpy; the commented-out reinterpret_cast would break strict aliasing.
    uint32 bits; // = *reinterpret_cast<uint32*>(&value);		//violating strict aliasing!
    std::memcpy(&bits, &value, sizeof(float));
    /*			uint16 hbits = (bits>>16) & 0x8000;
                bits &= 0x7FFFFFFF;
                int exp = bits >> 23;
                if(exp == 255)
                    return hbits | 0x7C00 | (0x3FF&-static_cast<unsigned>((bits&0x7FFFFF)!=0));
                if(exp > 142)
                    if(R == std::round_toward_infinity)
                        return hbits | 0x7C00 - (hbits>>15);
                    if(R == std::round_toward_neg_infinity)
                        return hbits | 0x7BFF + (hbits>>15);
                    return hbits | 0x7BFF + (R!=std::round_toward_zero);
                int g, s;
                if(exp > 112)
                    g = (bits>>12) & 1;
                    s = (bits&0xFFF) != 0;
                    hbits |= ((exp-112)<<10) | ((bits>>13)&0x3FF);
                else if(exp > 101)
                    int i = 125 - exp;
                    bits = (bits&0x7FFFFF) | 0x800000;
                    g = (bits>>i) & 1;
                    s = (bits&((1L<<i)-1)) != 0;
                    hbits |= bits >> (i+1);
                    g = 0;
                    s = bits != 0;
                if(R == std::round_to_nearest)
                    #if HALF_ROUND_TIES_TO_EVEN
                        hbits += g & (s|hbits);
                        hbits += g;
                else if(R == std::round_toward_infinity)
                    hbits += ~(hbits>>15) & (s|g);
                else if(R == std::round_toward_neg_infinity)
                    hbits += (hbits>>15) & (g|s);
    // NOTE(review): the statements below look like the end of a table-driven half-to-float conversion; see above.
    uint32 bits = mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10];
    //			return *reinterpret_cast<float*>(&bits);			//violating strict aliasing!
    float out;
    std::memcpy(&out, &bits, sizeof(float));
    return out;

/// Convert half-precision to IEEE double-precision.
/// Builds the upper 32 bits of the double directly from the half's sign, exponent and mantissa.
/// \param value binary representation of half-precision value
/// \return double-precision value
inline double half2float_impl(uint16 value, double, true_type)
{
    typedef bits<float>::type uint32;
    typedef bits<double>::type uint64;
    // Sign bit 15 of the half becomes bit 31 of the double's upper word.
    uint32 hi = static_cast<uint32>(value & 0x8000) << 16;
    int abs = value & 0x7FFF;
    if (abs)
    {
        // 0x3F000000 == (1023 - 15) << 20 re-biases the exponent; for Inf/NaN (abs >= 0x7C00) it is doubled so the
        // all-ones half exponent maps to the all-ones double exponent.
        hi |= 0x3F000000 << static_cast<unsigned>(abs >= 0x7C00);
        // Normalize subnormals: shift the mantissa up until the implicit bit (0x400) is set, lowering the exponent
        // by one per step. The loop body is intentionally empty.
        for (; abs < 0x400; abs <<= 1, hi -= 0x100000)
        {
        }
        // Exponent and mantissa move to bits 20+ and 10..19; a set implicit bit adds one to the exponent field.
        hi += static_cast<uint32>(abs) << 10;
    }
    uint64 bits = static_cast<uint64>(hi) << 32;
    // memcpy instead of reinterpret_cast to avoid violating strict aliasing.
    double out;
    std::memcpy(&out, &bits, sizeof(double));
    return out;
}

/// Convert half-precision to non-IEEE floating point.
/// \tparam T type to convert to (builtin integer type)
/// \param value binary representation of half-precision value
/// \return floating point value
template <typename T>
T half2float_impl(uint16 value, T, ...)
    T out;
    int abs = value & 0x7FFF;
    if (abs > 0x7C00)
        out = std::numeric_limits<T>::has_quiet_NaN ? std::numeric_limits<T>::quiet_NaN() : T();
    else if (abs == 0x7C00)
        out = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
    else if (abs > 0x3FF)
        out = std::ldexp(static_cast<T>((abs & 0x3FF) | 0x400), (abs >> 10) - 25);
        out = std::ldexp(static_cast<T>(abs), -24);
    return (value & 0x8000) ? -out : out;

/// Convert half-precision to floating point.
/// \tparam T type to convert to (builtin integer type)
/// \param value binary representation of half-precision value
/// \return floating point value
template <typename T>
T half2float(uint16 value)
    return half2float_impl(
        value, T(), bool_type < std::numeric_limits<T>::is_iec559 && sizeof(typename bits<T>::type) == sizeof(T) > ());

/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign
/// bits) \param value binary representation of half-precision value \return integral value
template <std::float_round_style R, bool E, typename T>
T half2int_impl(uint16 value)
    static_assert(std::is_integral<T>::value, "half to int conversion only supports builtin integer types");
    uint32_t e = value & 0x7FFF;
    if (e >= 0x7C00)
        return (value & 0x8000) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
    if (e < 0x3800)
        if (R == std::round_toward_infinity)
            return T(~(value >> 15) & (e != 0));
        else if (R == std::round_toward_neg_infinity)
            return -T(value > 0x8000);
        return T();
    uint32_t m = (value & 0x3FF) | 0x400;
    e >>= 10;
    if (e < 25)
        if (R == std::round_to_nearest)
            m += (1 << (24 - e)) - (~(m >> (25 - e)) & E);
        else if (R == std::round_toward_infinity)
            m += ((value >> 15) - 1) & ((1 << (25 - e)) - 1U);
        else if (R == std::round_toward_neg_infinity)
            m += -(value >> 15) & ((1 << (25 - e)) - 1U);
        m >>= 25 - e;
        m <<= e - 25;
    return (value & 0x8000) ? -static_cast<T>(m) : static_cast<T>(m);

/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign
/// bits) \param value binary representation of half-precision value \return integral value
template <std::float_round_style R, typename T>
T half2int(uint16 value)
    return half2int_impl<R, HALF_ROUND_TIES_TO_EVEN, T>(value);

/// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign
/// bits) \param value binary representation of half-precision value \return integral value
template <typename T>
T half2int_up(uint16 value)
    return half2int_impl<std::round_to_nearest, 0, T>(value);

/// Round half-precision number to nearest integer value.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \param value binary representation of half-precision value
/// \return half-precision bits for nearest integral value
template <std::float_round_style R, bool E>
uint16 round_half_impl(uint16 value)
{
    uint32_t e = value & 0x7FFF;
    uint16 result = value;
    if (e < 0x3C00)
    {
        // |value| < 1: the result is a signed zero or a signed one (0x3C00).
        result &= 0x8000;
        if (R == std::round_to_nearest)
            result |= 0x3C00U & -(e >= (0x3800 + E));
        else if (R == std::round_toward_infinity)
            result |= 0x3C00U & -(~(value >> 15) & (e != 0));
        else if (R == std::round_toward_neg_infinity)
            result |= 0x3C00U & -(value > 0x8000);
    }
    else if (e < 0x6400)
    {
        // 1 <= |value| < 1024 (0x6400): some mantissa bits are fractional. Values from 1024 upwards, Inf and NaN are
        // already integral and are returned unchanged.
        e = 25 - (e >> 10);
        uint32_t mask = (1 << e) - 1;
        if (R == std::round_to_nearest)
            result += (1 << (e - 1)) - (~(result >> e) & E);
        else if (R == std::round_toward_infinity)
            result += mask & ((value >> 15) - 1);
        else if (R == std::round_toward_neg_infinity)
            result += mask & -(value >> 15);
        result &= ~mask;
    }
    return result;
}

/// Round half-precision number to nearest integer value.
/// Ties are resolved according to HALF_ROUND_TIES_TO_EVEN.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \param value binary representation of half-precision value
/// \return half-precision bits for nearest integral value
template <std::float_round_style R>
uint16 round_half(uint16 value)
{
    return round_half_impl<R, HALF_ROUND_TIES_TO_EVEN>(value);
}

/// Round half-precision number to nearest integer value using round-to-nearest-away-from-zero.
/// \param value binary representation of half-precision value
/// \return half-precision bits for nearest integral value
inline uint16 round_half_up(uint16 value)
{
    return round_half_impl<std::round_to_nearest, false>(value);
}
/// \}

// Forward declarations of the implementation helpers that class half befriends.
struct functions;
template <typename> struct unary_specialized;
template <typename, typename> struct binary_specialized;
template <typename, typename, std::float_round_style> struct half_caster;
} // namespace detail

/// Half-precision floating point type.
/// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
/// conversions. It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and
/// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations
/// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
/// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
/// expressions are kept in single-precision as long as possible (while of course still maintaining a strong
/// half-precision type).
/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
/// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
/// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be
/// of exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will
/// most probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying
/// 16-bit IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16
/// bits if your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this should be the
/// case on nearly any reasonable platform.
/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable
/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation.
class half
    friend struct detail::functions;
    friend struct detail::unary_specialized<half>;
    friend struct detail::binary_specialized<half, half>;
    template <typename, typename, std::float_round_style>
    friend struct detail::half_caster;
    friend class std::numeric_limits<half>;
    friend struct std::hash<half>;
    friend half literal::operator"" _h(long double);

    /// Default constructor.
    /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics
    /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics.
    HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {}

    /// Copy constructor.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to copy from
    half(detail::expr rhs)
        : data_(detail::float2half<round_style>(static_cast<float>(rhs)))

    /// Conversion constructor.
    /// \param rhs float to convert
    explicit half(float rhs)
        : data_(detail::float2half<round_style>(rhs))

    /// Conversion to single-precision.
    /// \return single precision value representing expression value
    operator float() const
        return detail::half2float<float>(data_);

    /// Assignment operator.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to copy from
    /// \return reference to this half
    half& operator=(detail::expr rhs)
        return *this = static_cast<float>(rhs);

    /// Arithmetic assignment.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to add
    /// \return reference to this half
    template <typename T>
    typename detail::enable<half&, T>::type operator+=(T rhs)
        return *this += static_cast<float>(rhs);

    /// Arithmetic assignment.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to subtract
    /// \return reference to this half
    template <typename T>
    typename detail::enable<half&, T>::type operator-=(T rhs)
        return *this -= static_cast<float>(rhs);

    /// Arithmetic assignment.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to multiply with
    /// \return reference to this half
    template <typename T>
    typename detail::enable<half&, T>::type operator*=(T rhs)
        return *this *= static_cast<float>(rhs);

    /// Arithmetic assignment.
    /// \tparam T type of concrete half expression
    /// \param rhs half expression to divide by
    /// \return reference to this half
    template <typename T>
    typename detail::enable<half&, T>::type operator/=(T rhs)
        return *this /= static_cast<float>(rhs);

    /// Assignment operator.
    /// \param rhs single-precision value to copy from
    /// \return reference to this half
    half& operator=(float rhs)
        data_ = detail::float2half<round_style>(rhs);
        return *this;

    /// Arithmetic assignment.
    /// \param rhs single-precision value to add
    /// \return reference to this half
    half& operator+=(float rhs)
        data_ = detail::float2half<round_style>(detail::half2float<float>(data_) + rhs);
        return *this;

    /// Arithmetic assignment.
    /// \param rhs single-precision value to subtract
    /// \return reference to this half
    half& operator-=(float rhs)
        data_ = detail::float2half<round_style>(detail::half2float<float>(data_) - rhs);
        return *this;

    /// Arithmetic assignment.
    /// \param rhs single-precision value to multiply with
    /// \return reference to this half
    half& operator*=(float rhs)
        data_ = detail::float2half<round_style>(detail::half2float<float>(data_) * rhs);
        return *this;

    /// Arithmetic assignment.
    /// \param rhs single-precision value to divide by
    /// \return reference to this half
    half& operator/=(float rhs)
        data_ = detail::float2half<round_style>(detail::half2float<float>(data_) / rhs);
        return *this;

    /// Prefix increment.
    /// \return incremented half value
    half& operator++()
        return *this += 1.0f;

    /// Prefix decrement.
    /// \return decremented half value
    half& operator--()
        return *this -= 1.0f;

    /// Postfix increment.
    /// \return non-incremented half value
    half operator++(int)
        half out(*this);
        return out;

    /// Postfix decrement.
    /// \return non-decremented half value
    half operator--(int)
        half out(*this);
        return out;

    /// Rounding mode to use
    static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE);

    /// Constructor.
    /// \param bits binary representation to set half to
    HALF_CONSTEXPR half(detail::binary_t, detail::uint16 bits) HALF_NOEXCEPT : data_(bits) {}

    /// Internal binary representation
    detail::uint16 data_;

namespace literal
{
/// Half literal.
/// While this returns an actual half-precision value, half literals can unfortunately not be constant expressions due
/// to rather involved conversions.
/// \param value literal value
/// \return half with given value (if representable)
inline half operator"" _h(long double value)
{
    // Uses the private bit-pattern constructor and rounding mode; this operator is a friend of half.
    return half(detail::binary, detail::float2half<half::round_style>(value));
}
} // namespace literal

namespace detail
/// Wrapper implementing unspecialized half-precision functions.
struct functions
    /// Addition implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision sum stored in single-precision
    static expr plus(float x, float y)
        return expr(x + y);

    /// Subtraction implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision difference stored in single-precision
    static expr minus(float x, float y)
        return expr(x - y);

    /// Multiplication implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision product stored in single-precision
    static expr multiplies(float x, float y)
        return expr(x * y);

    /// Division implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision quotient stored in single-precision
    static expr divides(float x, float y)
        return expr(x / y);

    /// Output implementation.
    /// \param out stream to write to
    /// \param arg value to write
    /// \return reference to stream
    template <typename charT, typename traits>
    static std::basic_ostream<charT, traits>& write(std::basic_ostream<charT, traits>& out, float arg)
        return out << arg;

    /// Input implementation.
    /// \param in stream to read from
    /// \param arg half to read into
    /// \return reference to stream
    template <typename charT, typename traits>
    static std::basic_istream<charT, traits>& read(std::basic_istream<charT, traits>& in, half& arg)
        float f;
        if (in >> f)
            arg = f;
        return in;

    /// Modulo implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision division remainder stored in single-precision
    static expr fmod(float x, float y)
        return expr(std::fmod(x, y));

    /// Remainder implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Half-precision division remainder stored in single-precision
    static expr remainder(float x, float y)
        return expr(std::remainder(x, y));
        if (builtin_isnan(x) || builtin_isnan(y))
            return expr(std::numeric_limits<float>::quiet_NaN());
        float ax = std::fabs(x), ay = std::fabs(y);
        if (ax >= 65536.0f || ay < std::ldexp(1.0f, -24))
            return expr(std::numeric_limits<float>::quiet_NaN());
        if (ay >= 65536.0f)
            return expr(x);
        if (ax == ay)
            return expr(builtin_signbit(x) ? -0.0f : 0.0f);
        ax = std::fmod(ax, ay + ay);
        float y2 = 0.5f * ay;
        if (ax > y2)
            ax -= ay;
            if (ax >= y2)
                ax -= ay;
        return expr(builtin_signbit(x) ? -ax : ax);

    /// Remainder implementation.
    /// \param x first operand
    /// \param y second operand
    /// \param quo address to store quotient bits at
    /// \return Half-precision division remainder stored in single-precision
    static expr remquo(float x, float y, int* quo)
        return expr(std::remquo(x, y, quo));
        if (builtin_isnan(x) || builtin_isnan(y))
            return expr(std::numeric_limits<float>::quiet_NaN());
        bool sign = builtin_signbit(x), qsign = static_cast<bool>(sign ^ builtin_signbit(y));
        float ax = std::fabs(x), ay = std::fabs(y);
        if (ax >= 65536.0f || ay < std::ldexp(1.0f, -24))
            return expr(std::numeric_limits<float>::quiet_NaN());
        if (ay >= 65536.0f)
            return expr(x);
        if (ax == ay)
            return *quo = qsign ? -1 : 1, expr(sign ? -0.0f : 0.0f);
        ax = std::fmod(ax, 8.0f * ay);
        int cquo = 0;
        if (ax >= 4.0f * ay)
            ax -= 4.0f * ay;
            cquo += 4;
        if (ax >= 2.0f * ay)
            ax -= 2.0f * ay;
            cquo += 2;
        float y2 = 0.5f * ay;
        if (ax > y2)
            ax -= ay;
            if (ax >= y2)
                ax -= ay;
        return *quo = qsign ? -cquo : cquo, expr(sign ? -ax : ax);

    /// Positive difference implementation.
    /// \param x first operand
    /// \param y second operand
    /// \return Positive difference stored in single-precision
    static expr fdim(float x, float y)
        return expr(std::fdim(x, y));
        return expr((x <= y) ? 0.0f : (x - y));

    /// Fused multiply-add implementation.
    /// \param x first operand
    /// \param y second operand
    /// \param z third operand
    /// \return \a x * \a y + \a z stored in single-precision
    static expr fma(float x, float y, float z)
        return expr(std::fma(x, y, z));
        return expr(x * y + z);

    /// Get NaN.
    /// \return Half-precision quiet NaN
    static half nanh()
        return half(binary, 0x7FFF);

    /// Exponential implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr exp(float arg)
        return expr(std::exp(arg));

    /// Exponential implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr expm1(float arg)
        return expr(std::expm1(arg));
        return expr(static_cast<float>(std::exp(static_cast<double>(arg)) - 1.0));

    /// Binary exponential implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr exp2(float arg)
        return expr(std::exp2(arg));
        return expr(static_cast<float>(std::exp(arg * 0.69314718055994530941723212145818)));

    /// Logarithm implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr log(float arg)
        return expr(std::log(arg));

    /// Common logarithm implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr log10(float arg)
        return expr(std::log10(arg));

    /// Logarithm implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr log1p(float arg)
        return expr(std::log1p(arg));
        return expr(static_cast<float>(std::log(1.0 + arg)));

    /// Binary logarithm implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr log2(float arg)
        return expr(std::log2(arg));
        return expr(static_cast<float>(std::log(static_cast<double>(arg)) * 1.4426950408889634073599246810019));

    /// Square root implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr sqrt(float arg)
        return expr(std::sqrt(arg));

    /// Cubic root implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr cbrt(float arg)
        return expr(std::cbrt(arg));
        if (builtin_isnan(arg) || builtin_isinf(arg))
            return expr(arg);
        return expr(builtin_signbit(arg) ? -static_cast<float>(std::pow(-static_cast<double>(arg), 1.0 / 3.0))
                                         : static_cast<float>(std::pow(static_cast<double>(arg), 1.0 / 3.0)));

    /// Hypotenuse implementation.
    /// \param x first argument
    /// \param y second argument
    /// \return function value stored in single-preicision
    static expr hypot(float x, float y)
        return expr(std::hypot(x, y));
        return expr((builtin_isinf(x) || builtin_isinf(y))
                ? std::numeric_limits<float>::infinity()
                : static_cast<float>(std::sqrt(static_cast<double>(x) * x + static_cast<double>(y) * y)));

    /// Power implementation.
    /// \param base value to exponentiate
    /// \param exp power to expontiate to
    /// \return function value stored in single-preicision
    static expr pow(float base, float exp)
        return expr(std::pow(base, exp));

    /// Sine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr sin(float arg)
        return expr(std::sin(arg));

    /// Cosine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr cos(float arg)
        return expr(std::cos(arg));

    /// Tan implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr tan(float arg)
        return expr(std::tan(arg));

    /// Arc sine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr asin(float arg)
        return expr(std::asin(arg));

    /// Arc cosine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr acos(float arg)
        return expr(std::acos(arg));

    /// Arc tangent implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr atan(float arg)
        return expr(std::atan(arg));

    /// Arc tangent implementation.
    /// \param x first argument
    /// \param y second argument
    /// \return function value stored in single-preicision
    static expr atan2(float x, float y)
        return expr(std::atan2(x, y));

    /// Hyperbolic sine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr sinh(float arg)
        return expr(std::sinh(arg));

    /// Hyperbolic cosine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr cosh(float arg)
        return expr(std::cosh(arg));

    /// Hyperbolic tangent implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr tanh(float arg)
        return expr(std::tanh(arg));

    /// Hyperbolic area sine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr asinh(float arg)
        return expr(std::asinh(arg));
        return expr((arg == -std::numeric_limits<float>::infinity())
                ? arg
                : static_cast<float>(std::log(arg + std::sqrt(arg * arg + 1.0))));

    /// Hyperbolic area cosine implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr acosh(float arg)
        return expr(std::acosh(arg));
        return expr((arg < -1.0f) ? std::numeric_limits<float>::quiet_NaN()
                                  : static_cast<float>(std::log(arg + std::sqrt(arg * arg - 1.0))));

    /// Hyperbolic area tangent implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr atanh(float arg)
        return expr(std::atanh(arg));
        return expr(static_cast<float>(0.5 * std::log((1.0 + arg) / (1.0 - arg))));

    /// Error function implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr erf(float arg)
        return expr(std::erf(arg));
        return expr(static_cast<float>(erf(static_cast<double>(arg))));

    /// Complementary implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr erfc(float arg)
        return expr(std::erfc(arg));
        return expr(static_cast<float>(1.0 - erf(static_cast<double>(arg))));

    /// Gamma logarithm implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr lgamma(float arg)
        return expr(std::lgamma(arg));
        if (builtin_isinf(arg))
            return expr(std::numeric_limits<float>::infinity());
        if (arg < 0.0f)
            float i, f = std::modf(-arg, &i);
            if (f == 0.0f)
                return expr(std::numeric_limits<float>::infinity());
            return expr(static_cast<float>(1.1447298858494001741434273513531
                - std::log(std::abs(std::sin(3.1415926535897932384626433832795 * f))) - lgamma(1.0 - arg)));
        return expr(static_cast<float>(lgamma(static_cast<double>(arg))));

    /// Gamma implementation.
    /// \param arg function argument
    /// \return function value stored in single-preicision
    static expr tgamma(float arg)
        return expr(std::tgamma(arg));
        if (arg == 0.0f)
            return builtin_signbit(arg) ? expr(-std::numeric_limits<float>::infinity())
                                        : expr(std::numeric_limits<float>::infinity());
        if (arg < 0.0f)
            float i, f = std::modf(-arg, &i);
            if (f == 0.0f)
                return expr(std::numeric_limits<float>::quiet_NaN());
            double value = 3.1415926535897932384626433832795
                / (std::sin(3.1415926535897932384626433832795 * f) * std::exp(lgamma(1.0 - arg)));
            return expr(static_cast<float>((std::fmod(i, 2.0f) == 0.0f) ? -value : value));
        if (builtin_isinf(arg))
            return expr(arg);
        return expr(static_cast<float>(std::exp(lgamma(static_cast<double>(arg)))));

    /// Floor implementation.
    /// \param arg value to round
    /// \return rounded value
    static half floor(half arg)
        return half(binary, round_half<std::round_toward_neg_infinity>(arg.data_));

    /// Ceiling implementation.
    /// \param arg value to round
    /// \return rounded value
    static half ceil(half arg)
        return half(binary, round_half<std::round_toward_infinity>(arg.data_));

    /// Truncation implementation.
    /// \param arg value to round
    /// \return rounded value
    static half trunc(half arg)
        return half(binary, round_half<std::round_toward_zero>(arg.data_));

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static half round(half arg)
        return half(binary, round_half_up(arg.data_));

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static long lround(half arg)
        return detail::half2int_up<long>(arg.data_);

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static half rint(half arg)
        return half(binary, round_half<half::round_style>(arg.data_));

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static long lrint(half arg)
        return detail::half2int<half::round_style, long>(arg.data_);

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static long long llround(half arg)
        return detail::half2int_up<long long>(arg.data_);

    /// Nearest integer implementation.
    /// \param arg value to round
    /// \return rounded value
    static long long llrint(half arg)
        return detail::half2int<half::round_style, long long>(arg.data_);

    /// Decompression implementation.
    /// \param arg number to decompress
    /// \param exp address to store exponent at
    /// \return normalized significant
    static half frexp(half arg, int* exp)
        int m = arg.data_ & 0x7FFF, e = -14;
        if (m >= 0x7C00 || !m)
            return *exp = 0, arg;
        for (; m < 0x400; m <<= 1, --e)
        return *exp = e + (m >> 10), half(binary, (arg.data_ & 0x8000) | 0x3800 | (m & 0x3FF));

    /// Decompression implementation.
    /// \param arg number to decompress
    /// \param iptr address to store integer part at
    /// \return fractional part
    static half modf(half arg, half* iptr)
        uint32_t e = arg.data_ & 0x7FFF;
        if (e >= 0x6400)
            return *iptr = arg, half(binary, arg.data_ & (0x8000U | -(e > 0x7C00)));
        if (e < 0x3C00)
            return iptr->data_ = arg.data_ & 0x8000, arg;
        e >>= 10;
        uint32_t mask = (1 << (25 - e)) - 1, m = arg.data_ & mask;
        iptr->data_ = arg.data_ & ~mask;
        if (!m)
            return half(binary, arg.data_ & 0x8000);
        for (; m < 0x400; m <<= 1, --e)
        return half(binary, static_cast<uint16>((arg.data_ & 0x8000) | (e << 10) | (m & 0x3FF)));

    /// Scaling implementation.
    /// \param arg number to scale
    /// \param exp power of two to scale by
    /// \return scaled number
    static half scalbln(half arg, long exp)
        uint32_t m = arg.data_ & 0x7FFF;
        if (m >= 0x7C00 || !m)
            return arg;
        for (; m < 0x400; m <<= 1, --exp)
        exp += m >> 10;
        uint16 value = arg.data_ & 0x8000;
        if (exp > 30)
            if (half::round_style == std::round_toward_zero)
                value |= 0x7BFF;
            else if (half::round_style == std::round_toward_infinity)
                value |= 0x7C00 - (value >> 15);
            else if (half::round_style == std::round_toward_neg_infinity)
                value |= 0x7BFF + (value >> 15);
                value |= 0x7C00;
        else if (exp > 0)
            value |= (exp << 10) | (m & 0x3FF);
        else if (exp > -11)
            m = (m & 0x3FF) | 0x400;
            if (half::round_style == std::round_to_nearest)
                m += 1 << -exp;
                m -= (m >> (1 - exp)) & 1;
            else if (half::round_style == std::round_toward_infinity)
                m += ((value >> 15) - 1) & ((1 << (1 - exp)) - 1U);
            else if (half::round_style == std::round_toward_neg_infinity)
                m += -(value >> 15) & ((1 << (1 - exp)) - 1U);
            value |= m >> (1 - exp);
        else if (half::round_style == std::round_toward_infinity)
            value -= (value >> 15) - 1;
        else if (half::round_style == std::round_toward_neg_infinity)
            value += value >> 15;
        return half(binary, value);

    /// Exponent implementation.
    /// \param arg number to query
    /// \return floating point exponent
    static int ilogb(half arg)
        int abs = arg.data_ & 0x7FFF;
        if (!abs)
            return FP_ILOGB0;
        if (abs < 0x7C00)
            int exp = (abs >> 10) - 15;
            if (abs < 0x400)
                for (; abs < 0x200; abs <<= 1, --exp)
            return exp;
        if (abs > 0x7C00)
            return FP_ILOGBNAN;
        return INT_MAX;

    /// Exponent implementation.
    /// \param arg number to query
    /// \return floating point exponent
    static half logb(half arg)
        int abs = arg.data_ & 0x7FFF;
        if (!abs)
            return half(binary, 0xFC00);
        if (abs < 0x7C00)
            int exp = (abs >> 10) - 15;
            if (abs < 0x400)
                for (; abs < 0x200; abs <<= 1, --exp)
            uint16 bits = (exp < 0) << 15;
            if (exp)
                uint32_t m = std::abs(exp) << 6, e = 18;
                for (; m < 0x400; m <<= 1, --e)
                bits |= (e << 10) + m;
            return half(binary, bits);
        if (abs > 0x7C00)
            return arg;
        return half(binary, 0x7C00);

    /// Enumeration implementation.
    /// \param from number to increase/decrease
    /// \param to direction to enumerate into
    /// \return next representable number
    static half nextafter(half from, half to)
        uint16 fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF;
        if (fabs > 0x7C00)
            return from;
        if (tabs > 0x7C00 || from.data_ == to.data_ || !(fabs | tabs))
            return to;
        if (!fabs)
            return half(binary, (to.data_ & 0x8000) + 1);
        bool lt = ((fabs == from.data_) ? static_cast<int>(fabs) : -static_cast<int>(fabs))
            < ((tabs == to.data_) ? static_cast<int>(tabs) : -static_cast<int>(tabs));
        return half(binary, from.data_ + (((from.data_ >> 15) ^ static_cast<unsigned>(lt)) << 1) - 1);

    /// Enumeration implementation.
    /// \param from number to increase/decrease
    /// \param to direction to enumerate into
    /// \return next representable number
    static half nexttoward(half from, long double to)
        if (isnan(from))
            return from;
        long double lfrom = static_cast<long double>(from);
        if (builtin_isnan(to) || lfrom == to)
            return half(static_cast<float>(to));
        if (!(from.data_ & 0x7FFF))
            return half(binary, (static_cast<detail::uint16>(builtin_signbit(to)) << 15) + 1);
        return half(binary, from.data_ + (((from.data_ >> 15) ^ static_cast<unsigned>(lfrom < to)) << 1) - 1);

    /// Sign implementation
    /// \param x first operand
    /// \param y second operand
    /// \return composed value
    static half copysign(half x, half y)
        return half(binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000));

    /// Classification implementation.
    /// \param arg value to classify
    /// \retval true if infinite number
    /// \retval false else
    static int fpclassify(half arg)
        uint32_t abs = arg.data_ & 0x7FFF;
        return abs
            ? ((abs > 0x3FF) ? ((abs >= 0x7C00) ? ((abs > 0x7C00) ? FP_NAN : FP_INFINITE) : FP_NORMAL) : FP_SUBNORMAL)
            : FP_ZERO;

    /// Classification implementation.
    /// \param arg value to classify
    /// \retval true if finite number
    /// \retval false else
    static bool isfinite(half arg)
        return (arg.data_ & 0x7C00) != 0x7C00;

    /// Classification implementation.
    /// \param arg value to classify
    /// \retval true if infinite number
    /// \retval false else
    static bool isinf(half arg)
        return (arg.data_ & 0x7FFF) == 0x7C00;

    /// Classification implementation.
    /// \param arg value to classify
    /// \retval true if not a number
    /// \retval false else
    static bool isnan(half arg)
        return (arg.data_ & 0x7FFF) > 0x7C00;

    /// Classification implementation.
    /// \param arg value to classify
    /// \retval true if normal number
    /// \retval false else
    static bool isnormal(half arg)
        return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00);

    /// Sign bit implementation.
    /// \param arg value to check
    /// \retval true if signed
    /// \retval false if unsigned
    static bool signbit(half arg)
        return (arg.data_ & 0x8000) != 0;

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if operands equal
    /// \retval false else
    static bool isequal(half x, half y)
        return (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)) && !isnan(x);

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if operands not equal
    /// \retval false else
    static bool isnotequal(half x, half y)
        return (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)) || isnan(x);

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if \a x > \a y
    /// \retval false else
    static bool isgreater(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        return xabs <= 0x7C00 && yabs <= 0x7C00
            && (((xabs == x.data_) ? xabs : -xabs) > ((yabs == y.data_) ? yabs : -yabs));

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if \a x >= \a y
    /// \retval false else
    static bool isgreaterequal(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        return xabs <= 0x7C00 && yabs <= 0x7C00
            && (((xabs == x.data_) ? xabs : -xabs) >= ((yabs == y.data_) ? yabs : -yabs));

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if \a x < \a y
    /// \retval false else
    static bool isless(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        return xabs <= 0x7C00 && yabs <= 0x7C00
            && (((xabs == x.data_) ? xabs : -xabs) < ((yabs == y.data_) ? yabs : -yabs));

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if \a x <= \a y
    /// \retval false else
    static bool islessequal(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        return xabs <= 0x7C00 && yabs <= 0x7C00
            && (((xabs == x.data_) ? xabs : -xabs) <= ((yabs == y.data_) ? yabs : -yabs));

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if either \a x > \a y nor \a x < \a y
    /// \retval false else
    static bool islessgreater(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        if (xabs > 0x7C00 || yabs > 0x7C00)
            return false;
        int a = (xabs == x.data_) ? xabs : -xabs, b = (yabs == y.data_) ? yabs : -yabs;
        return a < b || a > b;

    /// Comparison implementation.
    /// \param x first operand
    /// \param y second operand
    /// \retval true if operand unordered
    /// \retval false else
    static bool isunordered(half x, half y)
        return isnan(x) || isnan(y);

    static double erf(double arg)
        if (builtin_isinf(arg))
            return (arg < 0.0) ? -1.0 : 1.0;
        double x2 = arg * arg, ax2 = 0.147 * x2,
               value = std::sqrt(1.0 - std::exp(-x2 * (1.2732395447351626861510701069801 + ax2) / (1.0 + ax2)));
        return builtin_signbit(arg) ? -value : value;

    static double lgamma(double arg)
        double v = 1.0;
        for (; arg < 8.0; ++arg)
            v *= arg;
        double w = 1.0 / (arg * arg);
        return (((((((-0.02955065359477124183006535947712 * w + 0.00641025641025641025641025641026) * w
                        + -0.00191752691752691752691752691753)
                           * w
                       + 8.4175084175084175084175084175084e-4)
                          * w
                      + -5.952380952380952380952380952381e-4)
                         * w
                     + 7.9365079365079365079365079365079e-4)
                        * w
                    + -0.00277777777777777777777777777778)
                       * w
                   + 0.08333333333333333333333333333333)
            / arg
            + 0.91893853320467274178032973640562 - std::log(v) - arg + (arg - 0.5) * std::log(arg);

/// Wrapper for unary half-precision functions needing specialization for individual argument types.
/// \tparam T argument type
template <typename T>
struct unary_specialized
    /// Negation implementation.
    /// \param arg value to negate
    /// \return negated value
    static HALF_CONSTEXPR half negate(half arg)
        return half(binary, arg.data_ ^ 0x8000);

    /// Absolute value implementation.
    /// \param arg function argument
    /// \return absolute value
    static half fabs(half arg)
        return half(binary, arg.data_ & 0x7FFF);
template <>
struct unary_specialized<expr>
    static HALF_CONSTEXPR expr negate(float arg)
        return expr(-arg);
    static expr fabs(float arg)
        return expr(std::fabs(arg));

/// Wrapper for binary half-precision functions needing specialization for individual argument types.
/// \tparam T first argument type
/// \tparam U second argument type
template <typename T, typename U>
struct binary_specialized
{
    /// Minimum implementation. A NaN operand is ignored in favour of the other one.
    /// \param x first operand
    /// \param y second operand
    /// \return minimum value
    static expr fmin(float x, float y)
    {
#if HALF_ENABLE_CPP11_CMATH
        return expr(std::fmin(x, y));
#else
        if (builtin_isnan(x))
        {
            return expr(y);
        }
        if (builtin_isnan(y))
        {
            return expr(x);
        }
        return expr(std::min(x, y));
#endif
    }

    /// Maximum implementation. A NaN operand is ignored in favour of the other one.
    /// \param x first operand
    /// \param y second operand
    /// \return maximum value
    static expr fmax(float x, float y)
    {
#if HALF_ENABLE_CPP11_CMATH
        return expr(std::fmax(x, y));
#else
        if (builtin_isnan(x))
        {
            return expr(y);
        }
        if (builtin_isnan(y))
        {
            return expr(x);
        }
        return expr(std::max(x, y));
#endif
    }
};
template <>
struct binary_specialized<half, half>
    static half fmin(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        if (xabs > 0x7C00)
            return y;
        if (yabs > 0x7C00)
            return x;
        return (((xabs == x.data_) ? xabs : -xabs) > ((yabs == y.data_) ? yabs : -yabs)) ? y : x;
    static half fmax(half x, half y)
        int xabs = x.data_ & 0x7FFF, yabs = y.data_ & 0x7FFF;
        if (xabs > 0x7C00)
            return y;
        if (yabs > 0x7C00)
            return x;
        return (((xabs == x.data_) ? xabs : -xabs) < ((yabs == y.data_) ? yabs : -yabs)) ? y : x;

/// Helper class for half casts.
/// This class template has to be specialized for all valid cast argument to define an appropriate static `cast` member
/// function and a corresponding `type` member denoting its return type.
/// The primary template is intentionally empty so unsupported combinations fail to compile.
/// \tparam T destination type
/// \tparam U source type
/// \tparam R rounding mode to use
template <typename T, typename U, std::float_round_style R = (std::float_round_style)(HALF_ROUND_STYLE)>
struct half_caster
{
};
/// Cast from an arithmetic type to half, rounding with \a R.
template <typename U, std::float_round_style R>
struct half_caster<half, U, R>
{
    static_assert(std::is_arithmetic<U>::value, "half_cast from non-arithmetic type unsupported");

    static half cast(U arg)
    {
        return cast_impl(arg, is_float<U>());
    }

    // Tag-dispatched implementations: floating-point sources vs. integer sources.
    static half cast_impl(U arg, true_type)
    {
        return half(binary, float2half<R>(arg));
    }

    static half cast_impl(U arg, false_type)
    {
        return half(binary, int2half<R>(arg));
    }
};
template <typename T, std::float_round_style R>
struct half_caster<T, half, R>
    static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported");

    static T cast(half arg)
        return cast_impl(arg, is_float<T>());

    static T cast_impl(half arg, true_type)
        return half2float<T>(arg.data_);
    static T cast_impl(half arg, false_type)
        return half2int<R, T>(arg.data_);
template <typename T, std::float_round_style R>
struct half_caster<T, expr, R>
    static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported");

    static T cast(expr arg)
        return cast_impl(arg, is_float<T>());

    static T cast_impl(float arg, true_type)
        return static_cast<T>(arg);
    static T cast_impl(half arg, false_type)
        return half2int<R, T>(arg.data_);
template <std::float_round_style R>
struct half_caster<half, half, R>
    static half cast(half arg)
        return arg;
/// Cast from a half expression to half: reuses the identity cast (expr converts to half).
template <std::float_round_style R>
struct half_caster<half, expr, R> : half_caster<half, half, R>
{
};

/// \name Comparison operators
/// \{

/// Comparison for equality.
/// \param x first operand
/// \param y second operand
/// \retval true if operands equal
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator==(T x, U y)
{
    return functions::isequal(x, y);
}

/// Comparison for inequality.
/// \param x first operand
/// \param y second operand
/// \retval true if operands not equal
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator!=(T x, U y)
{
    return functions::isnotequal(x, y);
}

/// Comparison for less than.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x less than \a y
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator<(T x, U y)
{
    return functions::isless(x, y);
}

/// Comparison for greater than.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x greater than \a y
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator>(T x, U y)
{
    return functions::isgreater(x, y);
}

/// Comparison for less equal.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x less equal \a y
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator<=(T x, U y)
    return functions::islessequal(x, y);

/// Comparison for greater equal.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x greater equal \a y
/// \retval false else
template <typename T, typename U>
typename enable<bool, T, U>::type operator>=(T x, U y)
    return functions::isgreaterequal(x, y);

/// \}
/// \name Arithmetic operators
/// \{

/// Add halfs.
/// \param x left operand
/// \param y right operand
/// \return sum of half expressions
template <typename T, typename U>
typename enable<expr, T, U>::type operator+(T x, U y)
{
    return functions::plus(x, y);
}

/// Subtract halfs.
/// \param x left operand
/// \param y right operand
/// \return difference of half expressions
template <typename T, typename U>
typename enable<expr, T, U>::type operator-(T x, U y)
{
    return functions::minus(x, y);
}

/// Multiply halfs.
/// \param x left operand
/// \param y right operand
/// \return product of half expressions
template <typename T, typename U>
typename enable<expr, T, U>::type operator*(T x, U y)
{
    return functions::multiplies(x, y);
}

/// Divide halfs.
/// \param x left operand
/// \param y right operand
/// \return quotient of half expressions
template <typename T, typename U>
typename enable<expr, T, U>::type operator/(T x, U y)
{
    return functions::divides(x, y);
}

/// Identity (unary plus).
/// \param arg operand
/// \return unchanged operand
template <typename T>
HALF_CONSTEXPR typename enable<T, T>::type operator+(T arg)
{
    return arg;
}

/// Negation (unary minus).
/// \param arg operand
/// \return negated operand
template <typename T>
HALF_CONSTEXPR typename enable<T, T>::type operator-(T arg)
{
    return unary_specialized<T>::negate(arg);
}

/// \}
/// \name Input and output
/// \{

/// Output operator.
/// \param out output stream to write into
/// \param arg half expression to write
/// \return reference to output stream
template <typename T, typename charT, typename traits>
typename enable<std::basic_ostream<charT, traits>&, T>::type operator<<(std::basic_ostream<charT, traits>& out, T arg)
    return functions::write(out, arg);

/// Input operator.
/// \param in input stream to read from
/// \param arg half to read into
/// \return reference to input stream
template <typename charT, typename traits>
std::basic_istream<charT, traits>& operator>>(std::basic_istream<charT, traits>& in, half& arg)
    return functions::read(in, arg);

/// \}
/// \name Basic mathematical operations
/// \{

/// Absolute value.
/// \param arg operand
/// \return absolute value of \a arg
// Generic form: template <typename T> typename enable<T, T>::type abs(T arg) { return unary_specialized<T>::fabs(arg); }
inline half abs(half arg)
{
    return unary_specialized<half>::fabs(arg);
}
inline expr abs(expr arg)
{
    return unary_specialized<expr>::fabs(arg);
}

/// Absolute value.
/// \param arg operand
/// \return absolute value of \a arg
// Generic form: template <typename T> typename enable<T, T>::type fabs(T arg) { return unary_specialized<T>::fabs(arg); }
inline half fabs(half arg)
{
    return unary_specialized<half>::fabs(arg);
}
inline expr fabs(expr arg)
{
    return unary_specialized<expr>::fabs(arg);
}

/// Remainder of division.
/// \param x first operand
/// \param y second operand
/// \return remainder of floating point division.
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type fmod(T x, U y)
//     { return functions::fmod(x, y); }
inline expr fmod(half x, half y)
{
    return functions::fmod(x, y);
}
inline expr fmod(half x, expr y)
{
    return functions::fmod(x, y);
}
inline expr fmod(expr x, half y)
{
    return functions::fmod(x, y);
}
inline expr fmod(expr x, expr y)
{
    return functions::fmod(x, y);
}

/// Remainder of division.
/// \param x first operand
/// \param y second operand
/// \return remainder of floating point division.
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type remainder(T x, U y)
//     { return functions::remainder(x, y); }
inline expr remainder(half x, half y)
{
    return functions::remainder(x, y);
}
inline expr remainder(half x, expr y)
{
    return functions::remainder(x, y);
}
inline expr remainder(expr x, half y)
{
    return functions::remainder(x, y);
}
inline expr remainder(expr x, expr y)
{
    return functions::remainder(x, y);
}

/// Remainder of division.
/// \param x first operand
/// \param y second operand
/// \param quo address to store some bits of quotient at
/// \return remainder of floating point division.
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type remquo(T x, U y, int* quo)
//     { return functions::remquo(x, y, quo); }
inline expr remquo(half x, half y, int* quo)
{
    return functions::remquo(x, y, quo);
}
inline expr remquo(half x, expr y, int* quo)
{
    return functions::remquo(x, y, quo);
}
inline expr remquo(expr x, half y, int* quo)
{
    return functions::remquo(x, y, quo);
}
inline expr remquo(expr x, expr y, int* quo)
{
    return functions::remquo(x, y, quo);
}

/// Fused multiply add.
/// \param x first operand
/// \param y second operand
/// \param z third operand
/// \return ( \a x * \a y ) + \a z rounded as one operation.
// Generic form: template <typename T, typename U, typename V> typename enable<expr, T, U, V>::type fma(T x, U y, V z)
//     { return functions::fma(x, y, z); }
inline expr fma(half x, half y, half z)
{
    return functions::fma(x, y, z);
}
inline expr fma(half x, half y, expr z)
{
    return functions::fma(x, y, z);
}
inline expr fma(half x, expr y, half z)
{
    return functions::fma(x, y, z);
}
inline expr fma(half x, expr y, expr z)
{
    return functions::fma(x, y, z);
}
inline expr fma(expr x, half y, half z)
{
    return functions::fma(x, y, z);
}
inline expr fma(expr x, half y, expr z)
{
    return functions::fma(x, y, z);
}
inline expr fma(expr x, expr y, half z)
{
    return functions::fma(x, y, z);
}
inline expr fma(expr x, expr y, expr z)
{
    return functions::fma(x, y, z);
}

/// Maximum of half expressions.
/// \param x first operand
/// \param y second operand
/// \return maximum of operands
// Generic form: template <typename T, typename U> typename result<T, U>::type fmax(T x, U y)
//     { return binary_specialized<T, U>::fmax(x, y); }
inline half fmax(half x, half y)
{
    return binary_specialized<half, half>::fmax(x, y);
}
inline expr fmax(half x, expr y)
{
    return binary_specialized<half, expr>::fmax(x, y);
}
inline expr fmax(expr x, half y)
{
    return binary_specialized<expr, half>::fmax(x, y);
}
inline expr fmax(expr x, expr y)
{
    return binary_specialized<expr, expr>::fmax(x, y);
}

/// Minimum of half expressions.
/// \param x first operand
/// \param y second operand
/// \return minimum of operands
// Generic form: template <typename T, typename U> typename result<T, U>::type fmin(T x, U y)
//     { return binary_specialized<T, U>::fmin(x, y); }
inline half fmin(half x, half y)
{
    return binary_specialized<half, half>::fmin(x, y);
}
inline expr fmin(half x, expr y)
{
    return binary_specialized<half, expr>::fmin(x, y);
}
inline expr fmin(expr x, half y)
{
    return binary_specialized<expr, half>::fmin(x, y);
}
inline expr fmin(expr x, expr y)
{
    return binary_specialized<expr, expr>::fmin(x, y);
}

/// Positive difference.
/// \param x first operand
/// \param y second operand
/// \return \a x - \a y or 0 if difference negative
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type fdim(T x, U y)
//     { return functions::fdim(x, y); }
inline expr fdim(half x, half y)
{
    return functions::fdim(x, y);
}
inline expr fdim(half x, expr y)
{
    return functions::fdim(x, y);
}
inline expr fdim(expr x, half y)
{
    return functions::fdim(x, y);
}
inline expr fdim(expr x, expr y)
{
    return functions::fdim(x, y);
}

/// Get NaN value.
/// The `const char*` argument is unnamed and ignored; it only mirrors the signature of `std::nan`.
/// \return quiet NaN
inline half nanh(const char*)
{
    return functions::nanh();
}

/// \}
/// \name Exponential functions
/// \{

/// Exponential function.
/// \param arg function argument
/// \return e raised to \a arg
// Generic form: template <typename T> typename enable<expr, T>::type exp(T arg) { return functions::exp(arg); }
inline expr exp(half arg)
{
    return functions::exp(arg);
}
inline expr exp(expr arg)
{
    return functions::exp(arg);
}

/// Exponential minus one.
/// \param arg function argument
/// \return e raised to \a arg subtracted by 1
// Generic form: template <typename T> typename enable<expr, T>::type expm1(T arg) { return functions::expm1(arg); }
inline expr expm1(half arg)
{
    return functions::expm1(arg);
}
inline expr expm1(expr arg)
{
    return functions::expm1(arg);
}

/// Binary exponential.
/// \param arg function argument
/// \return 2 raised to \a arg
// Generic form: template <typename T> typename enable<expr, T>::type exp2(T arg) { return functions::exp2(arg); }
inline expr exp2(half arg)
{
    return functions::exp2(arg);
}
inline expr exp2(expr arg)
{
    return functions::exp2(arg);
}

/// Natural logarithm.
/// \param arg function argument
/// \return logarithm of \a arg to base e
// Generic form: template <typename T> typename enable<expr, T>::type log(T arg) { return functions::log(arg); }
inline expr log(half arg)
{
    return functions::log(arg);
}
inline expr log(expr arg)
{
    return functions::log(arg);
}

/// Common logarithm.
/// \param arg function argument
/// \return logarithm of \a arg to base 10
// Generic form: template <typename T> typename enable<expr, T>::type log10(T arg) { return functions::log10(arg); }
inline expr log10(half arg)
{
    return functions::log10(arg);
}
inline expr log10(expr arg)
{
    return functions::log10(arg);
}

/// Natural logarithm of one plus the argument.
/// \param arg function argument
/// \return logarithm of \a arg plus 1 to base e
// Generic form: template <typename T> typename enable<expr, T>::type log1p(T arg) { return functions::log1p(arg); }
inline expr log1p(half arg)
{
    return functions::log1p(arg);
}
inline expr log1p(expr arg)
{
    return functions::log1p(arg);
}

/// Binary logarithm.
/// \param arg function argument
/// \return logarithm of \a arg to base 2
// Generic form: template <typename T> typename enable<expr, T>::type log2(T arg) { return functions::log2(arg); }
inline expr log2(half arg)
{
    return functions::log2(arg);
}
inline expr log2(expr arg)
{
    return functions::log2(arg);
}

/// \}
/// \name Power functions
/// \{

/// Square root.
/// \param arg function argument
/// \return square root of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type sqrt(T arg) { return functions::sqrt(arg); }
inline expr sqrt(half arg)
{
    return functions::sqrt(arg);
}
inline expr sqrt(expr arg)
{
    return functions::sqrt(arg);
}

/// Cubic root.
/// \param arg function argument
/// \return cubic root of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type cbrt(T arg) { return functions::cbrt(arg); }
inline expr cbrt(half arg)
{
    return functions::cbrt(arg);
}
inline expr cbrt(expr arg)
{
    return functions::cbrt(arg);
}

/// Hypotenuse function.
/// \param x first argument
/// \param y second argument
/// \return square root of sum of squares without internal over- or underflows
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type hypot(T x, U y)
//     { return functions::hypot(x, y); }
inline expr hypot(half x, half y)
{
    return functions::hypot(x, y);
}
inline expr hypot(half x, expr y)
{
    return functions::hypot(x, y);
}
inline expr hypot(expr x, half y)
{
    return functions::hypot(x, y);
}
inline expr hypot(expr x, expr y)
{
    return functions::hypot(x, y);
}

/// Power function.
/// \param base first argument
/// \param exp second argument
/// \return \a base raised to \a exp
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type pow(T base, U exp)
//     { return functions::pow(base, exp); }
inline expr pow(half base, half exp)
{
    return functions::pow(base, exp);
}
inline expr pow(half base, expr exp)
{
    return functions::pow(base, exp);
}
inline expr pow(expr base, half exp)
{
    return functions::pow(base, exp);
}
inline expr pow(expr base, expr exp)
{
    return functions::pow(base, exp);
}

/// \}
/// \name Trigonometric functions
/// \{

/// Sine function.
/// \param arg function argument
/// \return sine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type sin(T arg) { return functions::sin(arg); }
inline expr sin(half arg)
{
    return functions::sin(arg);
}
inline expr sin(expr arg)
{
    return functions::sin(arg);
}

/// Cosine function.
/// \param arg function argument
/// \return cosine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type cos(T arg) { return functions::cos(arg); }
inline expr cos(half arg)
{
    return functions::cos(arg);
}
inline expr cos(expr arg)
{
    return functions::cos(arg);
}

/// Tangent function.
/// \param arg function argument
/// \return tangent value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type tan(T arg) { return functions::tan(arg); }
inline expr tan(half arg)
{
    return functions::tan(arg);
}
inline expr tan(expr arg)
{
    return functions::tan(arg);
}

/// Arc sine.
/// \param arg function argument
/// \return arc sine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type asin(T arg) { return functions::asin(arg); }
inline expr asin(half arg)
{
    return functions::asin(arg);
}
inline expr asin(expr arg)
{
    return functions::asin(arg);
}

/// Arc cosine function.
/// \param arg function argument
/// \return arc cosine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type acos(T arg) { return functions::acos(arg); }
inline expr acos(half arg)
{
    return functions::acos(arg);
}
inline expr acos(expr arg)
{
    return functions::acos(arg);
}

/// Arc tangent function.
/// \param arg function argument
/// \return arc tangent value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type atan(T arg) { return functions::atan(arg); }
inline expr atan(half arg)
{
    return functions::atan(arg);
}
inline expr atan(expr arg)
{
    return functions::atan(arg);
}

/// Arc tangent function.
/// \param x first argument
/// \param y second argument
/// \return arc tangent value
// Generic form: template <typename T, typename U> typename enable<expr, T, U>::type atan2(T x, U y)
//     { return functions::atan2(x, y); }
inline expr atan2(half x, half y)
{
    return functions::atan2(x, y);
}
inline expr atan2(half x, expr y)
{
    return functions::atan2(x, y);
}
inline expr atan2(expr x, half y)
{
    return functions::atan2(x, y);
}
inline expr atan2(expr x, expr y)
{
    return functions::atan2(x, y);
}

/// \}
/// \name Hyperbolic functions
/// \{

/// Hyperbolic sine.
/// \param arg function argument
/// \return hyperbolic sine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type sinh(T arg) { return functions::sinh(arg); }
inline expr sinh(half arg)
{
    return functions::sinh(arg);
}
inline expr sinh(expr arg)
{
    return functions::sinh(arg);
}

/// Hyperbolic cosine.
/// \param arg function argument
/// \return hyperbolic cosine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type cosh(T arg) { return functions::cosh(arg); }
inline expr cosh(half arg)
{
    return functions::cosh(arg);
}
inline expr cosh(expr arg)
{
    return functions::cosh(arg);
}

/// Hyperbolic tangent.
/// \param arg function argument
/// \return hyperbolic tangent value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type tanh(T arg) { return functions::tanh(arg); }
inline expr tanh(half arg)
{
    return functions::tanh(arg);
}
inline expr tanh(expr arg)
{
    return functions::tanh(arg);
}

/// Hyperbolic area sine.
/// \param arg function argument
/// \return area sine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type asinh(T arg) { return functions::asinh(arg); }
inline expr asinh(half arg)
{
    return functions::asinh(arg);
}
inline expr asinh(expr arg)
{
    return functions::asinh(arg);
}

/// Hyperbolic area cosine.
/// \param arg function argument
/// \return area cosine value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type acosh(T arg) { return functions::acosh(arg); }
inline expr acosh(half arg)
{
    return functions::acosh(arg);
}
inline expr acosh(expr arg)
{
    return functions::acosh(arg);
}

/// Hyperbolic area tangent.
/// \param arg function argument
/// \return area tangent value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type atanh(T arg) { return functions::atanh(arg); }
inline expr atanh(half arg)
{
    return functions::atanh(arg);
}
inline expr atanh(expr arg)
{
    return functions::atanh(arg);
}

/// \}
/// \name Error and gamma functions
/// \{

/// Error function.
/// \param arg function argument
/// \return error function value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type erf(T arg) { return functions::erf(arg); }
inline expr erf(half arg)
{
    return functions::erf(arg);
}
inline expr erf(expr arg)
{
    return functions::erf(arg);
}

/// Complementary error function.
/// \param arg function argument
/// \return 1 minus error function value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type erfc(T arg) { return functions::erfc(arg); }
inline expr erfc(half arg)
{
    return functions::erfc(arg);
}
inline expr erfc(expr arg)
{
    return functions::erfc(arg);
}

/// Natural logarithm of gamma function.
/// \param arg function argument
/// \return natural logarithm of gamma function for \a arg
// Generic form: template <typename T> typename enable<expr, T>::type lgamma(T arg) { return functions::lgamma(arg); }
inline expr lgamma(half arg)
{
    return functions::lgamma(arg);
}
inline expr lgamma(expr arg)
{
    return functions::lgamma(arg);
}

/// Gamma function.
/// \param arg function argument
/// \return gamma function value of \a arg
// Generic form: template <typename T> typename enable<expr, T>::type tgamma(T arg) { return functions::tgamma(arg); }
inline expr tgamma(half arg)
{
    return functions::tgamma(arg);
}
inline expr tgamma(expr arg)
{
    return functions::tgamma(arg);
}

/// \}
/// \name Rounding
/// \{

/// Nearest integer not less than half value.
/// \param arg half to round
/// \return nearest integer not less than \a arg
// Generic form: template <typename T> typename enable<half, T>::type ceil(T arg) { return functions::ceil(arg); }
inline half ceil(half arg)
{
    return functions::ceil(arg);
}
inline half ceil(expr arg)
{
    return functions::ceil(arg);
}

/// Nearest integer not greater than half value.
/// \param arg half to round
/// \return nearest integer not greater than \a arg
// Generic form: template <typename T> typename enable<half, T>::type floor(T arg) { return functions::floor(arg); }
inline half floor(half arg)
{
    return functions::floor(arg);
}
inline half floor(expr arg)
{
    return functions::floor(arg);
}

/// Nearest integer not greater in magnitude than half value.
/// \param arg half to round
/// \return nearest integer not greater in magnitude than \a arg
// Generic form: template <typename T> typename enable<half, T>::type trunc(T arg) { return functions::trunc(arg); }
inline half trunc(half arg)
{
    return functions::trunc(arg);
}
inline half trunc(expr arg)
{
    return functions::trunc(arg);
}

/// Nearest integer.
/// \param arg half to round
/// \return nearest integer, rounded away from zero in half-way cases
// Generic form: template <typename T> typename enable<half, T>::type round(T arg) { return functions::round(arg); }
inline half round(half arg)
{
    return functions::round(arg);
}
inline half round(expr arg)
{
    return functions::round(arg);
}

/// Nearest integer.
/// \param arg half to round
/// \return nearest integer, rounded away from zero in half-way cases
// Generic form: template <typename T> typename enable<long, T>::type lround(T arg) { return functions::lround(arg); }
inline long lround(half arg)
{
    return functions::lround(arg);
}
inline long lround(expr arg)
{
    return functions::lround(arg);
}

/// Nearest integer using half's internal rounding mode.
/// Implemented with the same helper as rint().
/// \param arg half expression to round
/// \return nearest integer using default rounding mode
// Generic form: template <typename T> typename enable<half, T>::type nearbyint(T arg)
//     { return functions::nearbyint(arg); }
inline half nearbyint(half arg)
{
    return functions::rint(arg);
}
inline half nearbyint(expr arg)
{
    return functions::rint(arg);
}

/// Nearest integer using half's internal rounding mode.
/// \param arg half expression to round
/// \return nearest integer using default rounding mode
// Generic form: template <typename T> typename enable<half, T>::type rint(T arg) { return functions::rint(arg); }
inline half rint(half arg)
{
    return functions::rint(arg);
}
inline half rint(expr arg)
{
    return functions::rint(arg);
}

/// Nearest integer using half's internal rounding mode.
/// \param arg half expression to round
/// \return nearest integer using default rounding mode
// Generic form: template <typename T> typename enable<long, T>::type lrint(T arg) { return functions::lrint(arg); }
inline long lrint(half arg)
{
    return functions::lrint(arg);
}
inline long lrint(expr arg)
{
    return functions::lrint(arg);
}

/// Nearest integer.
/// \param arg half to round
/// \return nearest integer, rounded away from zero in half-way cases
// Generic form: template <typename T> typename enable<long long, T>::type llround(T arg)
//     { return functions::llround(arg); }
inline long long llround(half arg)
{
    return functions::llround(arg);
}
inline long long llround(expr arg)
{
    return functions::llround(arg);
}

/// Nearest integer using half's internal rounding mode.
/// \param arg half expression to round
/// \return nearest integer using default rounding mode
// Generic form: template <typename T> typename enable<long long, T>::type llrint(T arg)
//     { return functions::llrint(arg); }
inline long long llrint(half arg)
{
    return functions::llrint(arg);
}
inline long long llrint(expr arg)
{
    return functions::llrint(arg);
}

/// \}
/// \name Floating point manipulation
/// \{

/// Decompose floating point number.
/// \param arg number to decompose
/// \param exp address to store exponent at
/// \return significand in range [0.5, 1)
// Generic form: template <typename T> typename enable<half, T>::type frexp(T arg, int* exp)
//     { return functions::frexp(arg, exp); }
inline half frexp(half arg, int* exp)
{
    return functions::frexp(arg, exp);
}
inline half frexp(expr arg, int* exp)
{
    return functions::frexp(arg, exp);
}

/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
// Generic form: template <typename T> typename enable<half, T>::type ldexp(T arg, int exp)
//     { return functions::scalbln(arg, exp); }
inline half ldexp(half arg, int exp)
{
    return functions::scalbln(arg, exp);
}
inline half ldexp(expr arg, int exp)
{
    return functions::scalbln(arg, exp);
}

/// Extract integer and fractional parts.
/// \param arg number to decompose
/// \param iptr address to store integer part at
/// \return fractional part
// Generic form: template <typename T> typename enable<half, T>::type modf(T arg, half* iptr)
//     { return functions::modf(arg, iptr); }
inline half modf(half arg, half* iptr)
{
    return functions::modf(arg, iptr);
}
inline half modf(expr arg, half* iptr)
{
    return functions::modf(arg, iptr);
}

/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
// Generic form: template <typename T> typename enable<half, T>::type scalbn(T arg, int exp)
//     { return functions::scalbln(arg, exp); }
inline half scalbn(half arg, int exp)
{
    return functions::scalbln(arg, exp);
}
inline half scalbn(expr arg, int exp)
{
    return functions::scalbln(arg, exp);
}

/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multiplied by 2 raised to \a exp
// Generic form: template <typename T> typename enable<half, T>::type scalbln(T arg, long exp)
//     { return functions::scalbln(arg, exp); }
inline half scalbln(half arg, long exp)
{
    return functions::scalbln(arg, exp);
}
inline half scalbln(expr arg, long exp)
{
    return functions::scalbln(arg, exp);
}

/// Extract exponent.
/// \param arg number to query
/// \return floating point exponent
/// \retval FP_ILOGB0 for zero
/// \retval FP_ILOGBNAN for NaN
/// \retval INT_MAX for infinity
// Generic form: template <typename T> typename enable<int, T>::type ilogb(T arg) { return functions::ilogb(arg); }
inline int ilogb(half arg)
{
    return functions::ilogb(arg);
}
inline int ilogb(expr arg)
{
    return functions::ilogb(arg);
}

/// Extract exponent.
/// \param arg number to query
/// \return floating point exponent
// Generic form: template <typename T> typename enable<half, T>::type logb(T arg) { return functions::logb(arg); }
inline half logb(half arg)
{
    return functions::logb(arg);
}
inline half logb(expr arg)
{
    return functions::logb(arg);
}

/// Next representable value.
/// \param from value to compute next representable value for
/// \param to direction towards which to compute next value
/// \return next representable value after \a from in direction towards \a to
// Generic form: template <typename T, typename U> typename enable<half, T, U>::type nextafter(T from, U to)
//     { return functions::nextafter(from, to); }
inline half nextafter(half from, half to)
{
    return functions::nextafter(from, to);
}
inline half nextafter(half from, expr to)
{
    return functions::nextafter(from, to);
}
inline half nextafter(expr from, half to)
{
    return functions::nextafter(from, to);
}
inline half nextafter(expr from, expr to)
{
    return functions::nextafter(from, to);
}

/// Next representable value.
/// \param from value to compute next representable value for
/// \param to direction towards which to compute next value
/// \return next representable value after \a from in direction towards \a to
// Generic form: template <typename T> typename enable<half, T>::type nexttoward(T from, long double to)
//     { return functions::nexttoward(from, to); }
inline half nexttoward(half from, long double to)
{
    return functions::nexttoward(from, to);
}
inline half nexttoward(expr from, long double to)
{
    return functions::nexttoward(from, to);
}

/// Take sign.
/// \param x value to change sign for
/// \param y value to take sign from
/// \return value equal to \a x in magnitude and to \a y in sign
// Generic form: template <typename T, typename U> typename enable<half, T, U>::type copysign(T x, U y)
//     { return functions::copysign(x, y); }
inline half copysign(half x, half y)
{
    return functions::copysign(x, y);
}
inline half copysign(half x, expr y)
{
    return functions::copysign(x, y);
}
inline half copysign(expr x, half y)
{
    return functions::copysign(x, y);
}
inline half copysign(expr x, expr y)
{
    return functions::copysign(x, y);
}

/// \}
/// \name Floating point classification
/// \{

/// Classify floating point value.
/// \param arg number to classify
/// \retval FP_ZERO for positive and negative zero
/// \retval FP_SUBNORMAL for subnormal numbers
/// \retval FP_INFINITY for positive and negative infinity
/// \retval FP_NAN for NaNs
/// \retval FP_NORMAL for all other (normal) values
// Generic form: template <typename T> typename enable<int, T>::type fpclassify(T arg)
//     { return functions::fpclassify(arg); }
inline int fpclassify(half arg)
{
    return functions::fpclassify(arg);
}
inline int fpclassify(expr arg)
{
    return functions::fpclassify(arg);
}

/// Check if finite number.
/// \param arg number to check
/// \retval true if neither infinity nor NaN
/// \retval false else
// Generic form: template <typename T> typename enable<bool, T>::type isfinite(T arg) { return functions::isfinite(arg); }
inline bool isfinite(half arg)
{
    return functions::isfinite(arg);
}
inline bool isfinite(expr arg)
{
    return functions::isfinite(arg);
}

/// Check for infinity.
/// \param arg number to check
/// \retval true for positive or negative infinity
/// \retval false else
// Generic form: template <typename T> typename enable<bool, T>::type isinf(T arg) { return functions::isinf(arg); }
inline bool isinf(half arg)
{
    return functions::isinf(arg);
}
inline bool isinf(expr arg)
{
    return functions::isinf(arg);
}

/// Check for NaN.
/// \param arg number to check
/// \retval true for NaNs
/// \retval false else
// Generic form: template <typename T> typename enable<bool, T>::type isnan(T arg) { return functions::isnan(arg); }
inline bool isnan(half arg)
{
    return functions::isnan(arg);
}
inline bool isnan(expr arg)
{
    return functions::isnan(arg);
}

/// Check if normal number.
/// \param arg number to check
/// \retval true if normal number
/// \retval false if either subnormal, zero, infinity or NaN
// Generic form: template <typename T> typename enable<bool, T>::type isnormal(T arg) { return functions::isnormal(arg); }
inline bool isnormal(half arg)
{
    return functions::isnormal(arg);
}
inline bool isnormal(expr arg)
{
    return functions::isnormal(arg);
}

/// Check sign.
/// \param arg number to check
/// \retval true for negative number
/// \retval false for positive number
// Generic form: template <typename T> typename enable<bool, T>::type signbit(T arg) { return functions::signbit(arg); }
inline bool signbit(half arg)
{
    return functions::signbit(arg);
}
inline bool signbit(expr arg)
{
    return functions::signbit(arg);
}

/// \}
/// \name Comparison
/// \{

/// Comparison for greater than.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x greater than \a y
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type isgreater(T x, U y)
//     { return functions::isgreater(x, y); }
inline bool isgreater(half x, half y)
{
    return functions::isgreater(x, y);
}
inline bool isgreater(half x, expr y)
{
    return functions::isgreater(x, y);
}
inline bool isgreater(expr x, half y)
{
    return functions::isgreater(x, y);
}
inline bool isgreater(expr x, expr y)
{
    return functions::isgreater(x, y);
}

/// Comparison for greater equal.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x greater equal \a y
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type isgreaterequal(T x, U y)
//     { return functions::isgreaterequal(x, y); }
inline bool isgreaterequal(half x, half y)
{
    return functions::isgreaterequal(x, y);
}
inline bool isgreaterequal(half x, expr y)
{
    return functions::isgreaterequal(x, y);
}
inline bool isgreaterequal(expr x, half y)
{
    return functions::isgreaterequal(x, y);
}
inline bool isgreaterequal(expr x, expr y)
{
    return functions::isgreaterequal(x, y);
}

/// Comparison for less than.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x less than \a y
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type isless(T x, U y)
//     { return functions::isless(x, y); }
inline bool isless(half x, half y)
{
    return functions::isless(x, y);
}
inline bool isless(half x, expr y)
{
    return functions::isless(x, y);
}
inline bool isless(expr x, half y)
{
    return functions::isless(x, y);
}
inline bool isless(expr x, expr y)
{
    return functions::isless(x, y);
}

/// Comparison for less equal.
/// \param x first operand
/// \param y second operand
/// \retval true if \a x less equal \a y
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type islessequal(T x, U y)
//     { return functions::islessequal(x, y); }
inline bool islessequal(half x, half y)
{
    return functions::islessequal(x, y);
}
inline bool islessequal(half x, expr y)
{
    return functions::islessequal(x, y);
}
inline bool islessequal(expr x, half y)
{
    return functions::islessequal(x, y);
}
inline bool islessequal(expr x, expr y)
{
    return functions::islessequal(x, y);
}

/// Comparison for less or greater.
/// \param x first operand
/// \param y second operand
/// \retval true if either less or greater
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type islessgreater(T x, U y)
//     { return functions::islessgreater(x, y); }
inline bool islessgreater(half x, half y)
{
    return functions::islessgreater(x, y);
}
inline bool islessgreater(half x, expr y)
{
    return functions::islessgreater(x, y);
}
inline bool islessgreater(expr x, half y)
{
    return functions::islessgreater(x, y);
}
inline bool islessgreater(expr x, expr y)
{
    return functions::islessgreater(x, y);
}

/// Check if unordered.
/// \param x first operand
/// \param y second operand
/// \retval true if unordered (one or two NaN operands)
/// \retval false else
// Generic form: template <typename T, typename U> typename enable<bool, T, U>::type isunordered(T x, U y)
//     { return functions::isunordered(x, y); }
inline bool isunordered(half x, half y)
{
    return functions::isunordered(x, y);
}
inline bool isunordered(half x, expr y)
{
    return functions::isunordered(x, y);
}
inline bool isunordered(expr x, half y)
{
    return functions::isunordered(x, y);
}
inline bool isunordered(expr x, expr y)
{
    return functions::isunordered(x, y);
}

/// \name Casting
/// \{

/// Cast to or from half-precision floating point number.
/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted
/// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do.
/// It uses the default rounding mode.
/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types
/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler
/// error and casting between [half](\ref half_float::half)s is just a no-op.
/// \tparam T destination type (half or built-in arithmetic type)
/// \tparam U source type (half or built-in arithmetic type)
/// \param arg value to cast
/// \return \a arg converted to destination type
template <typename T, typename U>
T half_cast(U arg)
    return half_caster<T, U>::cast(arg);

/// Cast to or from half-precision floating point number.
/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted
/// directly using the given rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do.
/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types
/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler
/// error and casting between [half](\ref half_float::half)s is just a no-op.
/// \tparam T destination type (half or built-in arithmetic type)
/// \tparam R rounding mode to use.
/// \tparam U source type (half or built-in arithmetic type)
/// \param arg value to cast
/// \return \a arg converted to destination type
template <typename T, std::float_round_style R, typename U>
T half_cast(U arg)
    return half_caster<T, U, R>::cast(arg);
/// \}
} // namespace detail

// Re-export the detail implementations into half_float so they are found for `half` operands.

// Comparison, arithmetic and stream operators.
using detail::operator==;
using detail::operator!=;
using detail::operator<;
using detail::operator<=;
using detail::operator>;
using detail::operator>=;
using detail::operator+;
using detail::operator-;
using detail::operator*;
using detail::operator/;
using detail::operator<<;
using detail::operator>>;

// Basic operations.
using detail::abs;
using detail::fabs;
using detail::fmod;
using detail::remainder;
using detail::remquo;
using detail::fma;
using detail::fmax;
using detail::fmin;
using detail::fdim;
using detail::nanh;

// Exponential and logarithmic functions.
using detail::exp;
using detail::exp2;
using detail::expm1;
using detail::log;
using detail::log10;
using detail::log1p;
using detail::log2;

// Power functions.
using detail::pow;
using detail::sqrt;
using detail::cbrt;
using detail::hypot;

// Trigonometric functions.
using detail::sin;
using detail::cos;
using detail::tan;
using detail::asin;
using detail::acos;
using detail::atan;
using detail::atan2;

// Hyperbolic functions.
using detail::sinh;
using detail::cosh;
using detail::tanh;
using detail::asinh;
using detail::acosh;
using detail::atanh;

// Error and gamma functions.
using detail::erf;
using detail::erfc;
using detail::lgamma;
using detail::tgamma;

// Rounding.
using detail::ceil;
using detail::floor;
using detail::trunc;
using detail::round;
using detail::lround;
using detail::llround;
using detail::nearbyint;
using detail::rint;
using detail::lrint;
using detail::llrint;

// Floating-point manipulation.
using detail::frexp;
using detail::ldexp;
using detail::modf;
using detail::scalbn;
using detail::scalbln;
using detail::ilogb;
using detail::logb;
using detail::nextafter;
using detail::nexttoward;
using detail::copysign;

// Classification.
using detail::fpclassify;
using detail::isfinite;
using detail::isinf;
using detail::isnan;
using detail::isnormal;
using detail::signbit;

// Comparison functions.
using detail::isgreater;
using detail::isgreaterequal;
using detail::isless;
using detail::islessequal;
using detail::islessgreater;
using detail::isunordered;

// Casting.
using detail::half_cast;
} // namespace half_float

/// Extensions to the C++ standard library.
namespace std
/// Numeric limits for half-precision floats.
/// Because of the underlying single-precision implementation of many operations, it inherits some properties from
/// `std::numeric_limits<float>`.
template <>
class numeric_limits<half_float::half> : public numeric_limits<float>
    /// Supports signed values.
    static HALF_CONSTEXPR_CONST bool is_signed = true;

    /// Is not exact.
    static HALF_CONSTEXPR_CONST bool is_exact = false;

    /// Doesn't provide modulo arithmetic.
    static HALF_CONSTEXPR_CONST bool is_modulo = false;

    /// IEEE conformant.
    static HALF_CONSTEXPR_CONST bool is_iec559 = true;

    /// Supports infinity.
    static HALF_CONSTEXPR_CONST bool has_infinity = true;

    /// Supports quiet NaNs.
    static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true;

    /// Supports subnormal values.
    static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present;

    /// Rounding mode.
    /// Due to the mix of internal single-precision computations (using the rounding mode of the underlying
    /// single-precision implementation) with the rounding mode of the single-to-half conversions, the actual rounding
    /// mode might be `std::round_indeterminate` if the default half-precision rounding mode doesn't match the
    /// single-precision rounding mode.
    static HALF_CONSTEXPR_CONST float_round_style round_style
        = (std::numeric_limits<float>::round_style == half_float::half::round_style) ? half_float::half::round_style
                                                                                     : round_indeterminate;

    /// Significant digits.
    static HALF_CONSTEXPR_CONST int digits = 11;

    /// Significant decimal digits.
    static HALF_CONSTEXPR_CONST int digits10 = 3;

    /// Required decimal digits to represent all possible values.
    static HALF_CONSTEXPR_CONST int max_digits10 = 5;

    /// Number base.
    static HALF_CONSTEXPR_CONST int radix = 2;

    /// One more than smallest exponent.
    static HALF_CONSTEXPR_CONST int min_exponent = -13;

    /// Smallest normalized representable power of 10.
    static HALF_CONSTEXPR_CONST int min_exponent10 = -4;

    /// One more than largest exponent
    static HALF_CONSTEXPR_CONST int max_exponent = 16;

    /// Largest finitely representable power of 10.
    static HALF_CONSTEXPR_CONST int max_exponent10 = 4;

    /// Smallest positive normal value.
    static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x0400);

    /// Smallest finite value.
    static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0xFBFF);

    /// Largest finite value.
    static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x7BFF);

    /// Difference between one and next representable value.
    static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x1400);

    /// Maximum rounding error.
    static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00);

    /// Positive infinity.
    static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x7C00);

    /// Quiet NaN.
    static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x7FFF);

    /// Signalling NaN.
    static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x7DFF);

    /// Smallest positive subnormal value.
    static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW
        return half_float::half(half_float::detail::binary, 0x0001);

/// Hash function for half-precision floats.
/// This is only defined if C++11 `std::hash` is supported and enabled.
template <>
struct hash<half_float::half> //: unary_function<half_float::half,size_t>
    /// Type of function argument.
    typedef half_float::half argument_type;

    /// Function return type.
    typedef size_t result_type;

    /// Compute hash function.
    /// \param arg half to hash
    /// \return hash value
    result_type operator()(argument_type arg) const
        return hash<half_float::detail::uint16>()(static_cast<unsigned>(arg.data_) & -(arg.data_ != 0x8000));
} // namespace std

#pragma warning(pop)



 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.

#include "logger.h"
#include "ErrorRecorder.h"
#include "logging.h"
using namespace nvinfer1;
// Definition of the global error recorder that logger.h declares `extern`.
SampleErrorRecorder gRecorder;
namespace sample
// Process-wide logger; reports messages of severity kINFO and more severe.
Logger gLogger{Logger::Severity::kINFO};

// One stream consumer per severity. Each one captures gLogger's reportable severity when it is constructed.
// gLogger is defined earlier in this translation unit, so it is already initialized at that point.
LogStreamConsumer gLogVerbose = LOG_VERBOSE(gLogger);
LogStreamConsumer gLogInfo = LOG_INFO(gLogger);
LogStreamConsumer gLogWarning = LOG_WARN(gLogger);
LogStreamConsumer gLogError = LOG_ERROR(gLogger);
LogStreamConsumer gLogFatal = LOG_FATAL(gLogger);

//! Sets \p severity as the reportable threshold for gLogger (which receives TensorRT messages) and for every
//! per-severity stream consumer. This keeps TensorRT output and sample output filtered consistently.
//! \param severity Messages less severe than this level are suppressed.
void setReportableSeverity(Logger::Severity severity)
{
    gLogger.setReportableSeverity(severity);
    gLogVerbose.setReportableSeverity(severity);
    gLogInfo.setReportableSeverity(severity);
    gLogWarning.setReportableSeverity(severity);
    gLogError.setReportableSeverity(severity);
    gLogFatal.setReportableSeverity(severity);
}
} // namespace sample


 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.

#ifndef LOGGER_H
#define LOGGER_H

#include "logging.h"

// Global error recorder; only a forward declaration is needed for the extern declaration.
class SampleErrorRecorder;
extern SampleErrorRecorder gRecorder;

namespace sample
{
//! Process-wide logger (defined in logger.cpp).
extern Logger gLogger;

//! Stream-style loggers, one per severity, ordered from most to least severe.
extern LogStreamConsumer gLogFatal;
extern LogStreamConsumer gLogError;
extern LogStreamConsumer gLogWarning;
extern LogStreamConsumer gLogInfo;
extern LogStreamConsumer gLogVerbose;

//! Changes the reportable severity of the global loggers.
void setReportableSeverity(Logger::Severity level);
} // namespace sample

#endif // LOGGER_H


 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.


#include "NvInferRuntimeCommon.h"
#include "sampleOptions.h"
#include <cassert>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <ostream>
#include <sstream>
#include <string>

namespace sample

// Shorthand for TensorRT's logger severity enum, used throughout this header.
using Severity = nvinfer1::ILogger::Severity;

//! \class LogStreamConsumerBuffer
//! \brief String buffer that accumulates one log message.
//! On sync (flush) or destruction, it writes the message to the target stream with a timestamp and a prefix.
class LogStreamConsumerBuffer : public std::stringbuf
{
public:
    //! \param stream    destination stream (kept by reference; must outlive this buffer)
    //! \param prefix    text written between the timestamp and the message, e.g. "[I] "
    //! \param shouldLog when false, buffered text is discarded instead of written
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream)
        , mPrefix(prefix)
        , mShouldLog(shouldLog)
    {
    }

    //! Takes over only the destination, prefix and gate of \p other.
    //! Any text still buffered in \p other stays there and is emitted when \p other is destroyed.
    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) noexcept
        : mOutput(other.mOutput)
        , mPrefix(other.mPrefix)
        , mShouldLog(other.mShouldLog)
    {
    }
    LogStreamConsumerBuffer(const LogStreamConsumerBuffer& other) = delete;
    LogStreamConsumerBuffer() = delete;
    LogStreamConsumerBuffer& operator=(const LogStreamConsumerBuffer&) = delete;
    LogStreamConsumerBuffer& operator=(LogStreamConsumerBuffer&&) = delete;

    ~LogStreamConsumerBuffer() override
    {
        // pbase() is the start of the buffered output and pptr() the current write position.
        // If they differ, text was written but never synced, so emit it now.
        if (pbase() != pptr())
        {
            putOutput();
        }
    }

    //! Synchronizes the buffer: writes its contents to the stream, empties it and flushes the stream.
    //! \return 0 (always reports success)
    int sync() override
    {
        putOutput();
        return 0;
    }

    //! When logging is enabled, writes "[MM/DD/YYYY-HH:MM:SS] <prefix><buffered text>" to the stream.
    //! In all cases it then empties the buffer and flushes the stream.
    void putOutput()
    {
        if (mShouldLog)
        {
            const std::time_t timestamp = std::time(nullptr);
            // std::localtime may return nullptr; it also returns shared static storage (not thread-safe).
            const std::tm* tmLocal = std::localtime(&timestamp);
            if (tmLocal != nullptr)
            {
                // The fill character is sticky on the stream (often std::cout/std::cerr), so restore it afterwards.
                const char previousFill = mOutput.fill('0');
                mOutput << "[";
                mOutput << std::setw(2) << 1 + tmLocal->tm_mon << "/";
                mOutput << std::setw(2) << tmLocal->tm_mday << "/";
                mOutput << std::setw(4) << 1900 + tmLocal->tm_year << "-";
                mOutput << std::setw(2) << tmLocal->tm_hour << ":";
                mOutput << std::setw(2) << tmLocal->tm_min << ":";
                mOutput << std::setw(2) << tmLocal->tm_sec << "] ";
                mOutput.fill(previousFill);
            }
            // std::stringbuf::str() returns the buffered text; write it after the prefix.
            mOutput << mPrefix << str();
        }
        // Empty the buffer so the same text is never emitted twice.
        str("");
        // Flush the stream so the message becomes visible immediately.
        mOutput.flush();
    }

    //! Enables or disables emission of buffered text.
    void setShouldLog(bool shouldLog)
    {
        mShouldLog = shouldLog;
    }

private:
    std::ostream& mOutput;
    std::string mPrefix;
    bool mShouldLog{};
}; // class LogStreamConsumerBuffer

//! \class LogStreamConsumerBase
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
class LogStreamConsumerBase
    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mBuffer(stream, prefix, shouldLog)

    std::mutex mLogMutex;
    LogStreamConsumerBuffer mBuffer;
}; // class LogStreamConsumerBase

//! \class LogStreamConsumer
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
//!  Please do not change the order of the parent classes.
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
    //!  Reportable severity determines if the messages are severe enough to be logged.
    LogStreamConsumer(nvinfer1::ILogger::Severity reportableSeverity, nvinfer1::ILogger::Severity severity)
        : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(severity <= reportableSeverity)
        , mSeverity(severity)

    LogStreamConsumer(LogStreamConsumer&& other) noexcept
        : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(other.mShouldLog)
        , mSeverity(other.mSeverity)
    LogStreamConsumer(const LogStreamConsumer& other) = delete;
    LogStreamConsumer() = delete;
    ~LogStreamConsumer() = default;
    LogStreamConsumer& operator=(const LogStreamConsumer&) = delete;
    LogStreamConsumer& operator=(LogStreamConsumer&&) = delete;

    void setReportableSeverity(Severity reportableSeverity)
        mShouldLog = mSeverity <= reportableSeverity;

    std::mutex& getMutex()
        return mLogMutex;

    bool getShouldLog() const
        return mShouldLog;

    static std::ostream& severityOstream(Severity severity)
        return severity >= Severity::kINFO ? std::cout : std::cerr;

    static std::string severityPrefix(Severity severity)
        switch (severity)
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";

    bool mShouldLog;
    Severity mSeverity;
}; // class LogStreamConsumer

//! Inserts \p obj into \p logger, but only when the logger's severity is enabled.
//! The insertion runs under the logger's mutex. Returns \p logger so calls can be chained.
template <typename T>
LogStreamConsumer& operator<<(LogStreamConsumer& logger, const T& obj)
{
    if (!logger.getShouldLog())
    {
        return logger;
    }
    const std::lock_guard<std::mutex> lock(logger.getMutex());
    static_cast<std::ostream&>(logger) << obj;
    return logger;
}

//! Overload for stream manipulators passed as function pointers (e.g. std::endl, which also flushes
//! the buffered message). As with other insertions, it is gated and serialized by the logger's mutex.
inline LogStreamConsumer& operator<<(LogStreamConsumer& logger, std::ostream& (*manip)(std::ostream&) )
{
    if (!logger.getShouldLog())
    {
        return logger;
    }
    const std::lock_guard<std::mutex> lock(logger.getMutex());
    static_cast<std::ostream&>(logger) << manip;
    return logger;
}

//! Writes the extents of \p dims joined by "x" (e.g. "1x3x48x320"). Nothing is written when nbDims is 0.
inline LogStreamConsumer& operator<<(LogStreamConsumer& logger, const nvinfer1::Dims& dims)
{
    if (!logger.getShouldLog())
    {
        return logger;
    }
    const std::lock_guard<std::mutex> lock(logger.getMutex());
    std::ostream& out = static_cast<std::ostream&>(logger);
    for (int32_t axis = 0; axis < dims.nbDims; ++axis)
    {
        out << (axis == 0 ? "" : "x") << dims.d[axis];
    }
    return logger;
}

//! \class Logger
//! \brief Class which manages logging of TensorRT tools and samples
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
//! and supports logging two types of messages:
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
//! - Test pass/fail messages
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
//! In the future, this class could be extended to support dumping test results to a file in some standard format
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
//! library and messages coming from the sample.
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
//! object.
class Logger : public nvinfer1::ILogger
    explicit Logger(Severity severity = Severity::kWARNING)
        : mReportableSeverity(severity)

    //! \enum TestResult
    //! \brief Represents the state of a given test
    enum class TestResult
        kRUNNING, //!< The test is running
        kPASSED,  //!< The test passed
        kFAILED,  //!< The test failed
        kWAIVED   //!< The test was waived

    //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
    //! \return The nvinfer1::ILogger associated with this Logger
    //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
    //! we can eliminate the inheritance of Logger from ILogger
    nvinfer1::ILogger& getTRTLogger() noexcept
        return *this;

    //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
    //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
    //! inheritance from nvinfer1::ILogger
    void log(Severity severity, const char* msg) noexcept override
        LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;

    //! \brief Method for controlling the verbosity of logging output
    //! \param severity The logger will only emit messages that have severity of this level or higher.
    void setReportableSeverity(Severity severity) noexcept
        mReportableSeverity = severity;

    //! \brief Opaque handle that holds logging information for a particular test
    //! This object is an opaque handle to information used by the Logger to print test results.
    //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
    //! with Logger::reportTest{Start,End}().
    class TestAtom
        TestAtom(TestAtom&&) = default;

        friend class Logger;

        TestAtom(bool started, const std::string& name, const std::string& cmdline)
            : mStarted(started)
            , mName(name)
            , mCmdline(cmdline)

        bool mStarted;
        std::string mName;
        std::string mCmdline;

    //! \brief Define a test for logging
    //! \param[in] name The name of the test.  This should be a string starting with
    //!                  "TensorRT" and containing dot-separated strings containing
    //!                  the characters [A-Za-z0-9_].
    //!                  For example, "TensorRT.sample_googlenet"
    //! \param[in] cmdline The command line used to reproduce the test
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, const std::string& cmdline)
        return TestAtom(false, name, cmdline);

    //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
    //!        as input
    //! \param[in] name The name of the test
    //! \param[in] argc The number of command-line arguments
    //! \param[in] argv The array of command-line arguments (given as C strings)
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, int32_t argc, char const* const* argv)
        // Append TensorRT version as info
        const std::string vname = name + " [TensorRT v" + std::to_string(NV_TENSORRT_VERSION) + "]";
        auto cmdline = genCmdlineString(argc, argv);
        return defineTest(vname, cmdline);

    //! \brief Report that a test has started.
    //! \pre reportTestStart() has not been called yet for the given testAtom
    //! \param[in] testAtom The handle to the test that has started
    static void reportTestStart(TestAtom& testAtom)
        reportTestResult(testAtom, TestResult::kRUNNING);
        testAtom.mStarted = true;

    //! \brief Report that a test has ended.
    //! \pre reportTestStart() has been called for the given testAtom
    //! \param[in] testAtom The handle to the test that has ended
    //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
    //!                   TestResult::kFAILED, TestResult::kWAIVED
    static void reportTestEnd(TestAtom const& testAtom, TestResult result)
        assert(result != TestResult::kRUNNING);
        reportTestResult(testAtom, result);

    static int32_t reportPass(TestAtom const& testAtom)
        reportTestEnd(testAtom, TestResult::kPASSED);
        return EXIT_SUCCESS;

    static int32_t reportFail(TestAtom const& testAtom)
        reportTestEnd(testAtom, TestResult::kFAILED);
        return EXIT_FAILURE;

    static int32_t reportWaive(TestAtom const& testAtom)
        reportTestEnd(testAtom, TestResult::kWAIVED);
        return EXIT_SUCCESS;

    static int32_t reportTest(TestAtom const& testAtom, bool pass)
        return pass ? reportPass(testAtom) : reportFail(testAtom);

    Severity getReportableSeverity() const
        return mReportableSeverity;

    //! \brief returns an appropriate string for prefixing a log message with the given severity
    static const char* severityPrefix(Severity severity)
        switch (severity)
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";

    //! \brief returns an appropriate string for prefixing a test result message with the given result
    static const char* testResultString(TestResult result)
        switch (result)
        case TestResult::kRUNNING: return "RUNNING";
        case TestResult::kPASSED: return "PASSED";
        case TestResult::kFAILED: return "FAILED";
        case TestResult::kWAIVED: return "WAIVED";
        default: assert(0); return "";

    //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
    static std::ostream& severityOstream(Severity severity)
        return severity >= Severity::kINFO ? std::cout : std::cerr;

    //! \brief method that implements logging test results
    static void reportTestResult(TestAtom const& testAtom, TestResult result)
        severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
                                         << testAtom.mCmdline << std::endl;

    //! \brief generate a command line string from the given (argc, argv) values
    static std::string genCmdlineString(int32_t argc, char const* const* argv)
        std::stringstream ss;
        for (int32_t i = 0; i < argc; i++)
            if (i > 0)
                ss << " ";
            ss << argv[i];
        return ss.str();

    Severity mReportableSeverity;
}; // class Logger

//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
//! Example usage:
//!     LOG_VERBOSE(logger) << "hello world" << std::endl;
inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);

//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
//! Example usage:
//!     LOG_INFO(logger) << "hello world" << std::endl;
inline LogStreamConsumer LOG_INFO(const Logger& logger)
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);

//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
//! Example usage:
//!     LOG_WARN(logger) << "hello world" << std::endl;
inline LogStreamConsumer LOG_WARN(const Logger& logger)
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);

//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
//! Example usage:
//!     LOG_ERROR(logger) << "hello world" << std::endl;
inline LogStreamConsumer LOG_ERROR(const Logger& logger)
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);

//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
//!        ("fatal" severity)
//! Example usage:
//!     LOG_FATAL(logger) << "hello world" << std::endl;
inline LogStreamConsumer LOG_FATAL(const Logger& logger)
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
} // anonymous namespace
} // namespace sample


#include "engine.h"
#include <opencv2/opencv.hpp>
#include <chrono>
#include <iostream>
#include "qdebug.h"

using Clock = std::chrono::high_resolution_clock;

/// Returns the position of the first largest element in [first, last).
/// For an empty range the result is 0.
template <class ForwardIterator>
static inline std::size_t argmax(ForwardIterator first, ForwardIterator last)
{
    const ForwardIterator best = std::max_element(first, last);
    return static_cast<std::size_t>(std::distance(first, best));
}
int main()
  Options options;
  // TODO: Specify your precision here.
  options.FP16 = false;
  // TODO: Specify your input dimension here.
  options.inputDimension = {3,48,320}; // Modify to {3,32,320} when using ppocrv2
  // TODO: Specify your character_dict here.
  std::string label_path = "D:\\projects\\ocr\\data\\ppocr_keys_v1.txt";
  // TODO: Specify your test image here.
  const std::string inputImage = "D:\\projects\\ocr\\data\\test.jpg";
  // TODO: Specify your model here.
  const std::string onnxModelpath = "D:\\projects\\ocr\\data\\modelv3.onnx"; // Modify to "../data/modelv2.onnx" when using ppocrv2

  std::vector<std::string> label_list_ = ReadDict(label_path);
  Engine engine(options);

  bool succ =;
  if (!succ)
    throw std::runtime_error("Unable to build TRT engine.");

  succ = engine.loadNetwork();
  if (!succ)
    throw std::runtime_error("Unable to load TRT engine.");

  std::vector<cv::Mat> images;
  images.push_back(engine.preprocessImg(inputImage)); // Batchsize = 1

  // Do inference
  std::vector<std::vector<float>> featureVectors;
  int outsize;
  auto t1 = Clock::now(); // Discard the first inference time as it takes longer

  outsize = engine.runInference(images, featureVectors); // featureVectors[0] size: [W/4, 6625], in default character_dict.

  // Postprocess
  std::pair<std::vector<std::string>, double> res;
  std::vector<std::string> str_res;
  int argmax_idx;
  int last_index = 0;
  float score = 0.f;
  int count = 0;
  float max_value = 0.0f;
  int m = 0;

  // predict_shape = (1, 80, 6625) in the default model
  // 6625 = 6623 + 2, the length of character_dict is 6623 and 2 character for blank and space.
  std::vector<int> predict_shape = {1, engine.outputDims.d[1], engine.outputDims.d[2]};
  // CTC decode
  // Reference
  for (int n = 0; n < predict_shape[1]; n++)
    argmax_idx =
        int(argmax(&featureVectors[0][(m * predict_shape[1] + n) * predict_shape[2]],
                   &featureVectors[0][(m * predict_shape[1] + n + 1) * predict_shape[2]]));
    max_value =
        float(*std::max_element(&featureVectors[0][(m * predict_shape[1] + n) * predict_shape[2]],
                                &featureVectors[0][(m * predict_shape[1] + n + 1) * predict_shape[2]]));
    if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index)))
      score += max_value;
      count += 1;
      str_res.push_back(label_list_[argmax_idx - 1]);
      // I replace "label_list_[argmax_idx]" with "label_list_[argmax_idx - 1]" in my version, otherwise the result can't match.
      // I think this is decided by "use_space_char = True/False" in PaddleOCR config file.
    last_index = argmax_idx;
  score /= count;

  // Print result
  QString qstr = "";
  for (int i = 0; i < str_res.size(); i++)
      qstr += QString::fromStdString(str_res[i]);
  std::cout << "\tscore: " << std::setprecision(16) << score << std::endl;

  res.first = str_res;
  res.second = score;

  auto t2 = Clock::now();
  double totalTime = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

  // This time is a little more than actual average inference time.
  std::cout << "Success! Inference time: " << totalTime / static_cast<float>(images.size()) << " ms, for batch size of: " << images.size() << std::endl;
  return 0;


 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.


#include "cuda_runtime.h"
#include "NvInferRuntimeCommon.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// For loadLibrary
#ifdef _MSC_VER
// Needed so that the max/min definitions in windows.h do not conflict with std::max/min.
#define NOMINMAX
#include <windows.h>
#else
#include <dlfcn.h>
#endif

#undef CHECK
// CHECK(status): evaluate `status` exactly once; if the result compares unequal to 0, print it to
// std::cerr prefixed with "Cuda failure: " and abort() the process.
// NOTE(review): success must be encoded as 0 (true for cudaSuccess); other status types passed
// here must follow the same convention.
// The temporary uses a deliberately unusual name so that a caller variable (e.g. `ret`) passed as
// `status` is not shadowed by the macro's own local; `CHECK(ret)` previously expanded to the
// ill-formed `auto ret = (ret);`.
#define CHECK(status)                                                                                                  \
    do                                                                                                                 \
    {                                                                                                                  \
        auto checkMacroStatus_ = (status);                                                                             \
        if (checkMacroStatus_ != 0)                                                                                    \
        {                                                                                                              \
            std::cerr << "Cuda failure: " << checkMacroStatus_ << std::endl;                                           \
            abort();                                                                                                   \
        }                                                                                                              \
    } while (0)

// SAFE_ASSERT(condition): evaluate `condition` once; when it is false, print
// "Assertion failure: <condition text>" to std::cerr and abort() the process.
// Unlike assert(), this check is not compiled out under NDEBUG.
#define SAFE_ASSERT(condition)                                                  \
    do                                                                          \
    {                                                                           \
        if (condition)                                                          \
        {                                                                       \
            break;                                                              \
        }                                                                       \
        std::cerr << "Assertion failure: " << #condition << std::endl;          \
        abort();                                                                \
    } while (0)

namespace samplesCommon
template <typename T>
inline std::shared_ptr<T> infer_object(T* obj)
    if (!obj)
        throw std::runtime_error("Failed to create object");
    return std::shared_ptr<T>(obj);

inline uint32_t elementSize(nvinfer1::DataType t)
    switch (t)
    case nvinfer1::DataType::kINT32:
    case nvinfer1::DataType::kFLOAT: return 4;
    case nvinfer1::DataType::kHALF: return 2;
    case nvinfer1::DataType::kINT8: return 1;
    case nvinfer1::DataType::kUINT8: return 1;
    case nvinfer1::DataType::kBOOL: return 1;
    return 0;

template <typename A, typename B>
inline A divUp(A x, B n)
    return (x + n - 1) / n;

inline int64_t volume(nvinfer1::Dims const& d)
    return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies<int64_t>{});

// Return m rounded up to nearest multiple of n
template <typename T>
inline T roundUp(T m, T n)
    return ((m + n - 1) / n) * n;

//! comps is the number of components in a vector. Ignored if vecDim < 0.
inline int64_t volume(nvinfer1::Dims dims, int32_t vecDim, int32_t comps, int32_t batch)
    if (vecDim >= 0)
        dims.d[vecDim] = roundUp(dims.d[vecDim], comps);
    return samplesCommon::volume(dims) * std::max(batch, 1);

//! \class TrtCudaGraphSafe
//! \brief Managed CUDA graph
class TrtCudaGraphSafe
    explicit TrtCudaGraphSafe() = default;

    TrtCudaGraphSafe(const TrtCudaGraphSafe&) = delete;

    TrtCudaGraphSafe& operator=(const TrtCudaGraphSafe&) = delete;

    TrtCudaGraphSafe(TrtCudaGraphSafe&&) = delete;

    TrtCudaGraphSafe& operator=(TrtCudaGraphSafe&&) = delete;

        if (mGraphExec)

    void beginCapture(cudaStream_t& stream)
        // cudaStreamCaptureModeGlobal is the only allowed mode in SAFE CUDA
        CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));

    bool launch(cudaStream_t& stream)
        return cudaGraphLaunch(mGraphExec, stream) == cudaSuccess;

    void endCapture(cudaStream_t& stream)
        CHECK(cudaStreamEndCapture(stream, &mGraph));
        CHECK(cudaGraphInstantiate(&mGraphExec, mGraph, nullptr, nullptr, 0));

    void endCaptureOnError(cudaStream_t& stream)
        // There are two possibilities why stream capture would fail:
        // (1) stream is in cudaErrorStreamCaptureInvalidated state.
        // (2) TRT reports a failure.
        // In case (1), the returning mGraph should be nullptr.
        // In case (2), the returning mGraph is not nullptr, but it should not be used.
        const auto ret = cudaStreamEndCapture(stream, &mGraph);
        if (ret == cudaErrorStreamCaptureInvalidated)
            SAFE_ASSERT(mGraph == nullptr);
            SAFE_ASSERT(ret == cudaSuccess);
            SAFE_ASSERT(mGraph != nullptr);
            mGraph = nullptr;
        // Clean up any CUDA error.
        sample::gLogError << "The CUDA graph capture on the stream has failed." << std::endl;

    cudaGraph_t mGraph{};
    cudaGraphExec_t mGraphExec{};

inline void safeLoadLibrary(const std::string& path)
#ifdef _MSC_VER
    void* handle = LoadLibrary(path.c_str());
    int32_t flags{RTLD_LAZY};
    void* handle = dlopen(path.c_str(), flags);
    if (handle == nullptr)
#ifdef _MSC_VER
        sample::gLogError << "Could not load plugin library: " << path << std::endl;
        sample::gLogError << "Could not load plugin library: " << path << ", due to: " << dlerror() << std::endl;

inline std::vector<std::string> safeSplitString(std::string str, char delimiter = ',')
    std::vector<std::string> splitVect;
    std::stringstream ss(str);
    std::string substr;

    while (ss.good())
        getline(ss, substr, delimiter);
    return splitVect;

} // namespace samplesCommon



 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * See the License for the specific language governing permissions and
 * limitations under the License.


#include <algorithm>
#include <array>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "NvInfer.h"

namespace sample

// Build default params
constexpr int32_t maxBatchNotProvided{0};
constexpr int32_t defaultMinTiming{1};
constexpr int32_t defaultAvgTiming{8};

// System default params
constexpr int32_t defaultDevice{0};

// Inference default params
constexpr int32_t defaultBatch{1};
constexpr int32_t batchNotProvided{0};
constexpr int32_t defaultStreams{1};
constexpr int32_t defaultIterations{10};
constexpr float defaultWarmUp{200.F};
constexpr float defaultDuration{3.F};
constexpr float defaultSleep{};
constexpr float defaultIdle{};
constexpr float defaultPersistentCacheRatio{0};

// Reporting default params
constexpr int32_t defaultAvgRuns{10};
constexpr std::array<float, 3> defaultPercentiles{90, 95, 99};

enum class PrecisionConstraints

enum class ModelFormat

enum class SparsityFlag

enum class TimingCacheMode

using Arguments = std::unordered_multimap<std::string, std::string>;

using IOFormat = std::pair<nvinfer1::DataType, nvinfer1::TensorFormats>;

using ShapeRange = std::array<std::vector<int32_t>, nvinfer1::EnumMax<nvinfer1::OptProfileSelector>()>;

using LayerPrecisions = std::unordered_map<std::string, nvinfer1::DataType>;
using LayerOutputTypes = std::unordered_map<std::string, std::vector<nvinfer1::DataType>>;

struct Options
    virtual void parse(Arguments& arguments) = 0;

struct BaseModelOptions : public Options
    ModelFormat format{ModelFormat::kANY};
    std::string model;

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct UffInput : public Options
    std::vector<std::pair<std::string, nvinfer1::Dims>> inputs;
    bool NHWC{false};

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct ModelOptions : public Options
    BaseModelOptions baseModel;
    std::string prototxt;
    std::vector<std::string> outputs;
    UffInput uffInputs;

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct BuildOptions : public Options
    int32_t maxBatch{maxBatchNotProvided};
    double workspace{-1.0};
    double dlaSRAM{-1.0};
    double dlaLocalDRAM{-1.0};
    double dlaGlobalDRAM{-1.0};
    int32_t minTiming{defaultMinTiming};
    int32_t avgTiming{defaultAvgTiming};
    bool tf32{true};
    bool fp16{false};
    bool int8{false};
    bool directIO{false};
    PrecisionConstraints precisionConstraints{PrecisionConstraints::kNONE};
    LayerPrecisions layerPrecisions;
    LayerOutputTypes layerOutputTypes;
    bool safe{false};
    bool consistency{false};
    bool restricted{false};
    bool buildOnly{false};
    bool save{false};
    bool load{false};
    bool refittable{false};
    bool heuristic{false};
    SparsityFlag sparsity{SparsityFlag::kDISABLE};
    nvinfer1::ProfilingVerbosity profilingVerbosity{nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY};
    std::string engine;
    std::string calibration;
    std::unordered_map<std::string, ShapeRange> shapes;
    std::unordered_map<std::string, ShapeRange> shapesCalib;
    std::vector<IOFormat> inputFormats;
    std::vector<IOFormat> outputFormats;
    nvinfer1::TacticSources enabledTactics{0};
    nvinfer1::TacticSources disabledTactics{0};
    TimingCacheMode timingCacheMode{TimingCacheMode::kLOCAL};
    std::string timingCacheFile{};
    // C++11 does not automatically generate hash function for enum class.
    // Use int32_t to support C++11 compilers.
    std::unordered_map<int32_t, bool> previewFeatures;
    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct SystemOptions : public Options
    int32_t device{defaultDevice};
    int32_t DLACore{-1};
    bool fallback{false};
    std::vector<std::string> plugins;

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct InferenceOptions : public Options
    int32_t batch{batchNotProvided};
    int32_t iterations{defaultIterations};
    int32_t streams{defaultStreams};
    float warmup{defaultWarmUp};
    float duration{defaultDuration};
    float sleep{defaultSleep};
    float idle{defaultIdle};
    float persistentCacheRatio{defaultPersistentCacheRatio};
    bool overlap{true};
    bool skipTransfers{false};
    bool useManaged{false};
    bool spin{false};
    bool threads{false};
    bool graph{false};
    bool rerun{false};
    bool timeDeserialize{false};
    bool timeRefit{false};
    std::unordered_map<std::string, std::string> inputs;
    std::unordered_map<std::string, std::vector<int32_t>> shapes;
    nvinfer1::ProfilingVerbosity nvtxVerbosity{nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY};

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct ReportingOptions : public Options
    bool verbose{false};
    int32_t avgs{defaultAvgRuns};
    std::vector<float> percentiles{defaultPercentiles.begin(), defaultPercentiles.end()};
    bool refit{false};
    bool output{false};
    bool profile{false};
    bool layerInfo{false};
    std::string exportTimes;
    std::string exportOutput;
    std::string exportProfile;
    std::string exportLayerInfo;

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct SafeBuilderOptions : public Options
    std::string serialized{};
    std::string onnxModelFile{};
    bool help{false};
    bool verbose{false};
    std::vector<IOFormat> inputFormats;
    std::vector<IOFormat> outputFormats;
    bool int8{false};
    std::string calibFile{};
    std::vector<std::string> plugins;
    bool consistency{false};
    bool standard{false};
    TimingCacheMode timingCacheMode{TimingCacheMode::kLOCAL};
    std::string timingCacheFile{};
    SparsityFlag sparsity{SparsityFlag::kDISABLE};
    int32_t minTiming{defaultMinTiming};
    int32_t avgTiming{defaultAvgTiming};

    void parse(Arguments& arguments) override;

    static void printHelp(std::ostream& out);

struct AllOptions : public Options
    ModelOptions model;
    BuildOptions build;
    SystemOptions system;
    InferenceOptions inference;
    ReportingOptions reporting;
    bool helps{false};

    void parse(Arguments& arguments) override;

    static void help(std::ostream& out);

struct TaskInferenceOptions : public Options
    std::string engine;
    int32_t device{defaultDevice};
    int32_t DLACore{-1};
    int32_t batch{batchNotProvided};
    bool graph{false};
    float persistentCacheRatio{defaultPersistentCacheRatio};
    void parse(Arguments& arguments) override;
    static void help(std::ostream& out);

Arguments argsToArgumentsMap(int32_t argc, char* argv[]);

bool parseHelp(Arguments& arguments);

void helpHelp(std::ostream& out);

// Functions to print options

std::ostream& operator<<(std::ostream& os, const BaseModelOptions& options);

std::ostream& operator<<(std::ostream& os, const UffInput& input);

std::ostream& operator<<(std::ostream& os, const IOFormat& format);

std::ostream& operator<<(std::ostream& os, const ShapeRange& dims);

std::ostream& operator<<(std::ostream& os, const ModelOptions& options);

std::ostream& operator<<(std::ostream& os, const BuildOptions& options);

std::ostream& operator<<(std::ostream& os, const SystemOptions& options);

std::ostream& operator<<(std::ostream& os, const InferenceOptions& options);

std::ostream& operator<<(std::ostream& os, const ReportingOptions& options);

std::ostream& operator<<(std::ostream& os, const AllOptions& options);

std::ostream& operator<<(std::ostream& os, const SafeBuilderOptions& options);

inline std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& dims)
    for (int32_t i = 0; i < dims.nbDims; ++i)
        os << (i ? "x" : "") << dims.d[i];
    return os;
inline std::ostream& operator<<(std::ostream& os, const nvinfer1::WeightsRole role)
    switch (role)
    case nvinfer1::WeightsRole::kKERNEL:
        os << "Kernel";
    case nvinfer1::WeightsRole::kBIAS:
        os << "Bias";
    case nvinfer1::WeightsRole::kSHIFT:
        os << "Shift";
    case nvinfer1::WeightsRole::kSCALE:
        os << "Scale";
    case nvinfer1::WeightsRole::kCONSTANT:
        os << "Constant";
    case nvinfer1::WeightsRole::kANY:
        os << "Any";

    return os;

inline std::ostream& operator<<(std::ostream& os, const std::vector<int32_t>& vec)
    for (int32_t i = 0, e = static_cast<int32_t>(vec.size()); i < e; ++i)
        os << (i ? "x" : "") << vec[i];
    return os;

} // namespace sample


