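TensorRT - Extending the model input dimensions of the TensorRT C++ API with Dims5, Dims6, Dims7 and Dims8

1 Model input dimensions supported by the TensorRT C++ API

In TensorRT 7.0 and later releases that provide the binding-index API, the input dimensions are usually specified like this: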

    const std::string input_name = "input";
    const std::string output_name = "output";
    const int inputIndex = m_TensorRT_Engine->getBindingIndex(input_name.c_str());
    const int outputIndex = m_TensorRT_Engine->getBindingIndex(output_name.c_str());
    m_TensorRT_Context->setBindingDimensions(inputIndex, Dims3(3, 100, 20));

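Here Dims3 indicates that the model's input tensor is three-dimensional, with shape (3, 100, 20).

Most deep learning models take an input of shape (C, H, W), which is a three-dimensional tensor.

The TensorRT C++ API also ships a Dims4 helper for models with four-dimensional input, and that is the highest rank for which a helper class is provided. As deep learning frameworks and models grow more complex, more and more models need 5-D, 6-D, or even higher-rank tensors as network input. So how can the existing TensorRT API be extended to describe such higher-rank input tensors?

2 Extending the model input dimensions of the TensorRT C++ API

The file NvInferRuntimeCommon.h, in the include directory of the TensorRT C++ API, defines the class Dims32: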

//!
//! \class Dims
//! \brief Structure to define the dimensions of a tensor.
//!
//! TensorRT can also return an invalid dims structure. This structure is represented by nbDims == -1
//! and d[i] == 0 for all d.
//!
//! TensorRT can also return an "unknown rank" dims structure. This structure is represented by nbDims == -1
//! and d[i] == -1 for all d.
//!
class Dims32
{
public:
    //! The maximum rank (number of dimensions) supported for a tensor.
    static constexpr int32_t MAX_DIMS{8};
    //! The rank (number of dimensions).
    int32_t nbDims;
    //! The extent of each dimension.
    int32_t d[MAX_DIMS];
};

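This class, which the rest of the API refers to as Dims, describes the dimensions of a tensor. As the definition shows, it supports a maximum rank of 8.

NvInferLegacyDims.h, in the same include directory, defines the dimension helper classes that TensorRT currently provides: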
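Since nbDims and d are public members, a shape of any rank up to MAX_DIMS can in principle be described by filling them in directly. The following is a minimal sketch of that idea, not part of the TensorRT API: DimsStandIn is a hypothetical stand-in that mirrors the definition quoted above so that the snippet compiles without the TensorRT headers, and makeDims is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>

// Stand-in mirroring the Dims32 definition quoted above (assumption: used
// only so this sketch compiles without the TensorRT headers).
struct DimsStandIn
{
    static constexpr int32_t MAX_DIMS{8};
    int32_t nbDims;
    int32_t d[MAX_DIMS];
};

// Hypothetical helper: build a shape of any rank up to MAX_DIMS.
inline DimsStandIn makeDims(std::initializer_list<int32_t> extents)
{
    // The rank must fit into the fixed-size array d.
    assert(extents.size() <= static_cast<std::size_t>(DimsStandIn::MAX_DIMS));

    DimsStandIn dims{}; // value-initialization zeroes nbDims and every d[i]
    dims.nbDims = static_cast<int32_t>(extents.size());

    int32_t i = 0;
    for (int32_t e : extents)
    {
        dims.d[i++] = e;
    }
    return dims;
}
```

For example, makeDims({1, 3, 16, 100, 20}) yields nbDims == 5 with the unused trailing entries left at 0. In real code the same assignments would be applied to an nvinfer1::Dims object; whether the engine accepts the shape still depends on the model actually having an input of that rank.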

/*
 * Copyright 1993-2021 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H

#include "NvInferRuntimeCommon.h"

//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!

//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Descriptor for two-dimensional data.
//!
class Dims2 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims2 object.
    //!
    Dims2()
        : Dims{2, {}}
    {
    }

    //!
    //! \brief Construct a Dims2 from 2 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //!
    Dims2(int32_t d0, int32_t d1)
        : Dims{2, {d0, d1}}
    {
    }
};

//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
    //!
    //! \brief Construct an empty DimsHW object.
    //!
    DimsHW()
        : Dims2()
    {
    }

    //!
    //! \brief Construct a DimsHW given height and width.
    //!
    //! \param height the height of the data
    //! \param width the width of the data
    //!
    DimsHW(int32_t height, int32_t width)
        : Dims2(height, width)
    {
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t& h()
    {
        return d[0];
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t h() const
    {
        return d[0];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t& w()
    {
        return d[1];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t w() const
    {
        return d[1];
    }
};

//!
//! \class Dims3
//! \brief Descriptor for three-dimensional data.
//!
class Dims3 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims3 object.
    //!
    Dims3()
        : Dims{3, {}}
    {
    }

    //!
    //! \brief Construct a Dims3 from 3 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //!
    Dims3(int32_t d0, int32_t d1, int32_t d2)
        : Dims{3, {d0, d1, d2}}
    {
    }
};

//!
//! \class Dims4
//! \brief Descriptor for four-dimensional data.
//!
class Dims4 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims4 object.
    //!
    Dims4()
        : Dims{4, {}}
    {
    }

    //!
    //! \brief Construct a Dims4 from 4 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //!
    Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
        : Dims{4, {d0, d1, d2, d3}}
    {
    }
};

} // namespace nvinfer1

#endif // NV_INFER_LEGCY_DIMS_H

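As the code above shows, a dimension helper only needs to inherit from Dims and initialize nbDims and d accordingly. To make Dims5, Dims6, Dims7, and Dims8 available in TensorRT, we therefore define them ourselves. The extended NvInferLegacyDims.h looks like this: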

/*
 * Copyright 1993-2021 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H

#include "NvInferRuntimeCommon.h"

//!
//! \file NvInferLegacyDims.h
//!
//! This file contains declarations of legacy dimensions types which use channel
//! semantics in their names, and declarations on which those types rely.
//!

//!
//! \namespace nvinfer1
//!
//! \brief The TensorRT API version 1 namespace.
//!
namespace nvinfer1
{
//!
//! \class Dims2
//! \brief Descriptor for two-dimensional data.
//!
class Dims2 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims2 object.
    //!
    Dims2()
        : Dims{2, {}}
    {
    }

    //!
    //! \brief Construct a Dims2 from 2 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //!
    Dims2(int32_t d0, int32_t d1)
        : Dims{2, {d0, d1}}
    {
    }
};

//!
//! \class DimsHW
//! \brief Descriptor for two-dimensional spatial data.
//!
class DimsHW : public Dims2
{
public:
    //!
    //! \brief Construct an empty DimsHW object.
    //!
    DimsHW()
        : Dims2()
    {
    }

    //!
    //! \brief Construct a DimsHW given height and width.
    //!
    //! \param height the height of the data
    //! \param width the width of the data
    //!
    DimsHW(int32_t height, int32_t width)
        : Dims2(height, width)
    {
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t& h()
    {
        return d[0];
    }

    //!
    //! \brief Get the height.
    //!
    //! \return The height.
    //!
    int32_t h() const
    {
        return d[0];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t& w()
    {
        return d[1];
    }

    //!
    //! \brief Get the width.
    //!
    //! \return The width.
    //!
    int32_t w() const
    {
        return d[1];
    }
};

//!
//! \class Dims3
//! \brief Descriptor for three-dimensional data.
//!
class Dims3 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims3 object.
    //!
    Dims3()
        : Dims{3, {}}
    {
    }

    //!
    //! \brief Construct a Dims3 from 3 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //!
    Dims3(int32_t d0, int32_t d1, int32_t d2)
        : Dims{3, {d0, d1, d2}}
    {
    }
};

//!
//! \class Dims4
//! \brief Descriptor for four-dimensional data.
//!
class Dims4 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims4 object.
    //!
    Dims4()
        : Dims{4, {}}
    {
    }

    //!
    //! \brief Construct a Dims4 from 4 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //!
    Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
        : Dims{4, {d0, d1, d2, d3}}
    {
    }
};

//!
//! \class Dims5
//! \brief Descriptor for five-dimensional data.
//!
class Dims5 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims5 object.
    //!
    Dims5()
    {
        nbDims = 5;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims5 from 5 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //!
    Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
    {
        nbDims = 5;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};

//!
//! \class Dims6
//! \brief Descriptor for six-dimensional data.
//!
class Dims6 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims6 object.
    //!
    Dims6()
    {
        nbDims = 6;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims6 from 6 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //!
    Dims6(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5)
    {
        nbDims = 6;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};

//!
//! \class Dims7
//! \brief Descriptor for seven-dimensional data.
//!
class Dims7 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims7 object.
    //!
    Dims7()
    {
        nbDims = 7;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims7 from 7 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //! \param d6 The seventh element.
    //!
    Dims7(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6)
    {
        nbDims = 7;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        d[6] = d6;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};


//!
//! \class Dims8
//! \brief Descriptor for eight-dimensional data.
//!
class Dims8 : public Dims
{
public:
    //!
    //! \brief Construct an empty Dims8 object.
    //!
    Dims8()
    {
        nbDims = 8;
        for (int32_t i = 0; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }

    //!
    //! \brief Construct a Dims8 from 8 elements.
    //!
    //! \param d0 The first element.
    //! \param d1 The second element.
    //! \param d2 The third element.
    //! \param d3 The fourth element.
    //! \param d4 The fifth element.
    //! \param d5 The sixth element.
    //! \param d6 The seventh element.
    //! \param d7 The eighth element.
    //!
    Dims8(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6, int32_t d7)
    {
        nbDims = 8;
        d[0] = d0;
        d[1] = d1;
        d[2] = d2;
        d[3] = d3;
        d[4] = d4;
        d[5] = d5;
        d[6] = d6;
        d[7] = d7;
        for (int32_t i = nbDims; i < MAX_DIMS; ++i)
        {
            d[i] = 0;
        }
    }
};




} // namespace nvinfer1

#endif // NV_INFER_LEGCY_DIMS_H

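After modifying NvInferLegacyDims.h, rebuild your project. The new Dims5, Dims6, Dims7, and Dims8 classes can then be used to specify 5-D, 6-D, 7-D, and 8-D network input dimensions.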
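Edits made to a header shipped with the SDK are lost when TensorRT is reinstalled or upgraded, so one alternative is to keep the extra class in a header of your own project. The sketch below illustrates that variation under stated assumptions: DimsBase is a hypothetical stand-in for nvinfer1::Dims so that the snippet compiles on its own, myproject is a hypothetical namespace, and volume is a hypothetical helper that only demonstrates that the derived object can be passed wherever the base type is expected.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for nvinfer1::Dims (assumption: mirrors the Dims32 definition
// quoted earlier so the sketch compiles without the TensorRT headers).
struct DimsBase
{
    static constexpr int32_t MAX_DIMS{8};
    int32_t nbDims;
    int32_t d[MAX_DIMS];
};

namespace myproject // hypothetical project namespace
{
// Same idea as the Dims5 class above, but defined outside the SDK header.
class Dims5 : public DimsBase
{
public:
    // Empty 5-D shape: nbDims is 5 and every extent is zero.
    Dims5()
        : DimsBase{5, {}}
    {
    }

    // 5-D shape from five extents; the unused trailing entries are zeroed.
    Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
        : DimsBase{5, {d0, d1, d2, d3, d4}}
    {
    }
};

// Hypothetical helper: number of elements described by a shape, for
// example to size an input buffer. It accepts the base type, so a Dims5
// can be passed directly.
inline int64_t volume(const DimsBase& dims)
{
    assert(dims.nbDims >= 0 && dims.nbDims <= DimsBase::MAX_DIMS);
    int64_t v = 1;
    for (int32_t i = 0; i < dims.nbDims; ++i)
    {
        v *= dims.d[i];
    }
    return v;
}
} // namespace myproject
```

With the real headers, the class would derive from nvinfer1::Dims instead, and an object such as myproject::Dims5(1, 3, 16, 100, 20) would be passed to setBindingDimensions in the same way as Dims3 in the first snippet.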

If you are interested, visit my personal site for more articles: https://www.stubbornhuang.com/
