On verifying .wts weights in TensorRT

You can refer to the project I adapted:

https://github.com/lindsayshuo/yolov8-cls-tensorrtx

After building the network with the TensorRT API, you can verify the shapes flowing through it. The following code can be used as a reference:

for (const auto& kv : weightMap) {
    if (kv.first.find("conv.weight") != std::string::npos ||
        kv.first.find("linear.weight") != std::string::npos) {  // check conv.weight and linear.weight entries
        std::cout << "Weight name: " << kv.first << ", ";
        std::cout << "Count: " << kv.second.count << ", ";
        std::cout << "Type: " << (kv.second.type == nvinfer1::DataType::kFLOAT ? "FLOAT" :
                                  kv.second.type == nvinfer1::DataType::kHALF ? "HALF" : "INT8") << std::endl;
    }
}
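
For context, the weightMap iterated above is the map returned by the .wts loader. Below is a minimal sketch of such a loader, assuming the common tensorrtx text format (the first line holds the number of entries; each entry is a name, an element count, and that many hex-encoded 32-bit words); the loader in the linked repo may differ in details:

#include <NvInfer.h>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <map>
#include <string>

// Minimal .wts loader sketch (tensorrtx-style text format).
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string& file) {
    std::map<std::string, nvinfer1::Weights> weightMap;
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file.");

    int32_t count;
    input >> count;                          // first line: number of weight entries
    assert(count > 0 && "Invalid weight map file.");

    while (count--) {
        nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0};
        std::string name;
        uint32_t size;
        input >> name >> std::dec >> size;   // entry header: name and element count

        // each element is a 32-bit hex word that is reinterpreted as a float
        uint32_t* val = new uint32_t[size];
        for (uint32_t i = 0; i < size; ++i) {
            input >> std::hex >> val[i];
        }
        wt.values = val;
        wt.count = size;
        weightMap[name] = wt;
    }
    return weightMap;
}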

When this runs during engine building, the output looks like this:

Loading weights: ../yolov8n-cls.wts
Weight name: model.0.conv.weight, Count: 432, Type: FLOAT
Weight name: model.1.conv.weight, Count: 4608, Type: FLOAT
Weight name: model.2.cv1.conv.weight, Count: 1024, Type: FLOAT
Weight name: model.2.cv2.conv.weight, Count: 1536, Type: FLOAT
Weight name: model.2.m.0.cv1.conv.weight, Count: 2304, Type: FLOAT
Weight name: model.2.m.0.cv2.conv.weight, Count: 2304, Type: FLOAT
Weight name: model.3.conv.weight, Count: 18432, Type: FLOAT
Weight name: model.4.cv1.conv.weight, Count: 4096, Type: FLOAT
Weight name: model.4.cv2.conv.weight, Count: 8192, Type: FLOAT
Weight name: model.4.m.0.cv1.conv.weight, Count: 9216, Type: FLOAT
Weight name: model.4.m.0.cv2.conv.weight, Count: 9216, Type: FLOAT
Weight name: model.4.m.1.cv1.conv.weight, Count: 9216, Type: FLOAT
Weight name: model.4.m.1.cv2.conv.weight, Count: 9216, Type: FLOAT
Weight name: model.5.conv.weight, Count: 73728, Type: FLOAT
Weight name: model.6.cv1.conv.weight, Count: 16384, Type: FLOAT
Weight name: model.6.cv2.conv.weight, Count: 32768, Type: FLOAT
Weight name: model.6.m.0.cv1.conv.weight, Count: 36864, Type: FLOAT
Weight name: model.6.m.0.cv2.conv.weight, Count: 36864, Type: FLOAT
Weight name: model.6.m.1.cv1.conv.weight, Count: 36864, Type: FLOAT
Weight name: model.6.m.1.cv2.conv.weight, Count: 36864, Type: FLOAT
Weight name: model.7.conv.weight, Count: 294912, Type: FLOAT
Weight name: model.8.cv1.conv.weight, Count: 65536, Type: FLOAT
Weight name: model.8.cv2.conv.weight, Count: 98304, Type: FLOAT
Weight name: model.8.m.0.cv1.conv.weight, Count: 147456, Type: FLOAT
Weight name: model.8.m.0.cv2.conv.weight, Count: 147456, Type: FLOAT
Weight name: model.9.conv.conv.weight, Count: 327680, Type: FLOAT
Weight name: model.9.linear.weight, Count: 1280000, Type: FLOAT
[03/10/2024-23:30:13] [W] [TRT] The implicit batch dimension mode has been deprecated. Please create the network with NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag whenever possible.
max_channels : [1280]
Input shape: [3, 224, 224]
Calculated channel width: 16
Maximum channel limit: 1280
conv0 Dimensions(3): [ 16 112 112 ]
Calculated channel width: 32
Maximum channel limit: 1280
conv1 Dimensions(3): [ 32 56 56 ]
Calculated channel width: 32
Maximum channel limit: 1280
Calculated channel width: 32
Maximum channel limit: 1280
conv2 Dimensions(3): [ 32 56 56 ]
Calculated channel width: 64
Maximum channel limit: 1280
conv3 Dimensions(3): [ 64 28 28 ]
Calculated channel width: 64
Maximum channel limit: 1280
Calculated channel width: 64
Maximum channel limit: 1280
conv4 Dimensions(3): [ 64 28 28 ]
Calculated channel width: 128
Maximum channel limit: 1280
conv5 Dimensions(3): [ 128 14 14 ]
Calculated channel width: 128
Maximum channel limit: 1280
Calculated channel width: 128
Maximum channel limit: 1280
conv6 Dimensions(3): [ 128 14 14 ]
Calculated channel width: 256
Maximum channel limit: 1280
conv7 Dimensions(3): [ 256 7 7 ]
Calculated channel width: 256
Maximum channel limit: 1280
Calculated channel width: 256
Maximum channel limit: 1280
conv8 Dimensions(3): [ 256 7 7 ]
conv_class Dimensions(3): [ 1280 9 9 ]
Dimensions of the output from pool2 layer: 1280 1 1 
Number of feature maps: 1
model.9.linear.weight count: 1280000
Shape of model.9.linear.weight: [1000 x 1280]
Output shape of yolo: [1000, 1, 1]
Building engine, please wait for a while...
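
The lines such as "conv0 Dimensions(3): [ 16 112 112 ]" come from printing each layer's output dimensions while the network is being assembled. A minimal sketch of such a debug print, assuming a built nvinfer1::ILayer* (the helper name printLayerDims is hypothetical, not from the repo):

#include <NvInfer.h>
#include <iostream>

// Hypothetical debug helper: print a layer's output dimensions, e.g.
// "conv0 Dimensions(3): [ 16 112 112 ]".
void printLayerDims(const char* tag, nvinfer1::ILayer* layer) {
    nvinfer1::Dims d = layer->getOutput(0)->getDimensions();
    std::cout << tag << " Dimensions(" << d.nbDims << "): [ ";
    for (int32_t i = 0; i < d.nbDims; ++i) {
        std::cout << d.d[i] << " ";
    }
    std::cout << "]" << std::endl;
}

Because the network in the log above is built in the (deprecated) implicit-batch mode, the batch dimension is not part of the tensor, which is why only three values (C, H, W) are printed.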

Take "Weight name: model.0.conv.weight, Count: 432, Type: FLOAT" as an example: 432 equals 16 × 3 × 3 × 3, i.e. the product of the weight tensor's dimensions. To trace exactly where these numbers come from, you can inspect the corresponding ONNX model with the following code:

import onnx

model_in_file = 'yolov8n-cls.onnx'

if __name__ == "__main__":
    model = onnx.load(model_in_file)

    # Print node information: strides of every Conv node
    nodes = model.graph.node
    for node in nodes:
        if node.op_type == 'Conv':  # check whether the node is a convolution
            for attribute in node.attribute:
                if attribute.name == 'strides':
                    # get the stride attribute
                    strides = attribute.ints
                    print(f'Node name: {node.name}, Strides: {strides}')
    # Print every node's output tensor names ('stride_32' does not exist in the
    # cls model, so all outputs are simply listed)
    nodnum = len(nodes)

    for nid in range(nodnum):
        if nodes[nid].output[0] == 'stride_32':
            print('Found stride_32: index = ', nid)
        else:
            print(nodes[nid].output)

    # Print initializer information (name, data type, dims)
    inits = model.graph.initializer
    ininum = len(inits)

    for iid in range(ininum):
        el = inits[iid]
        print('name:', el.name, ' dtype:', el.data_type, ' dim:', el.dims)

    # Print graph output information
    print(model.graph.output)

    # Print graph input shapes
    inputs = model.graph.input
    for inp in inputs:
        # name, data type and shape of each input
        print('Input name:', inp.name)
        try:
            # the type information is wrapped and has to be unpacked
            tensor_type = inp.type.tensor_type
            # ONNX represents data types with an enum
            dtype = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
            # collect the shape (0 for unknown/dynamic dimensions)
            shape = [dim.dim_value for dim in tensor_type.shape.dim]
            print(' Data type:', dtype)
            print(' Shape:', shape)
        except Exception:
            print('Input shape is not fully defined.')

    print('Done')

Now look at the output:

Node name: /model.0/conv/Conv, Strides: [2, 2]
Node name: /model.1/conv/Conv, Strides: [2, 2]
Node name: /model.2/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.2/m.0/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.2/m.0/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.2/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.3/conv/Conv, Strides: [2, 2]
Node name: /model.4/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.4/m.0/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.4/m.0/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.4/m.1/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.4/m.1/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.4/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.5/conv/Conv, Strides: [2, 2]
Node name: /model.6/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.6/m.0/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.6/m.0/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.6/m.1/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.6/m.1/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.6/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.7/conv/Conv, Strides: [2, 2]
Node name: /model.8/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.8/m.0/cv1/conv/Conv, Strides: [1, 1]
Node name: /model.8/m.0/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.8/cv2/conv/Conv, Strides: [1, 1]
Node name: /model.9/conv/conv/Conv, Strides: [1, 1]
['/model.0/conv/Conv_output_0']
['/model.0/act/Sigmoid_output_0']
['/model.0/act/Mul_output_0']
['/model.1/conv/Conv_output_0']
['/model.1/act/Sigmoid_output_0']
['/model.1/act/Mul_output_0']
['/model.2/cv1/conv/Conv_output_0']
['/model.2/cv1/act/Sigmoid_output_0']
['/model.2/cv1/act/Mul_output_0']
['onnx::Split_64']
['/model.2/Split_output_0', '/model.2/Split_output_1']
['/model.2/m.0/cv1/conv/Conv_output_0']
['/model.2/m.0/cv1/act/Sigmoid_output_0']
['/model.2/m.0/cv1/act/Mul_output_0']
['/model.2/m.0/cv2/conv/Conv_output_0']
['/model.2/m.0/cv2/act/Sigmoid_output_0']
['/model.2/m.0/cv2/act/Mul_output_0']
['/model.2/m.0/Add_output_0']
['/model.2/Concat_output_0']
['/model.2/cv2/conv/Conv_output_0']
['/model.2/cv2/act/Sigmoid_output_0']
['/model.2/cv2/act/Mul_output_0']
['/model.3/conv/Conv_output_0']
['/model.3/act/Sigmoid_output_0']
['/model.3/act/Mul_output_0']
['/model.4/cv1/conv/Conv_output_0']
['/model.4/cv1/act/Sigmoid_output_0']
['/model.4/cv1/act/Mul_output_0']
['onnx::Split_84']
['/model.4/Split_output_0', '/model.4/Split_output_1']
['/model.4/m.0/cv1/conv/Conv_output_0']
['/model.4/m.0/cv1/act/Sigmoid_output_0']
['/model.4/m.0/cv1/act/Mul_output_0']
['/model.4/m.0/cv2/conv/Conv_output_0']
['/model.4/m.0/cv2/act/Sigmoid_output_0']
['/model.4/m.0/cv2/act/Mul_output_0']
['/model.4/m.0/Add_output_0']
['/model.4/m.1/cv1/conv/Conv_output_0']
['/model.4/m.1/cv1/act/Sigmoid_output_0']
['/model.4/m.1/cv1/act/Mul_output_0']
['/model.4/m.1/cv2/conv/Conv_output_0']
['/model.4/m.1/cv2/act/Sigmoid_output_0']
['/model.4/m.1/cv2/act/Mul_output_0']
['/model.4/m.1/Add_output_0']
['/model.4/Concat_output_0']
['/model.4/cv2/conv/Conv_output_0']
['/model.4/cv2/act/Sigmoid_output_0']
['/model.4/cv2/act/Mul_output_0']
['/model.5/conv/Conv_output_0']
['/model.5/act/Sigmoid_output_0']
['/model.5/act/Mul_output_0']
['/model.6/cv1/conv/Conv_output_0']
['/model.6/cv1/act/Sigmoid_output_0']
['/model.6/cv1/act/Mul_output_0']
['onnx::Split_111']
['/model.6/Split_output_0', '/model.6/Split_output_1']
['/model.6/m.0/cv1/conv/Conv_output_0']
['/model.6/m.0/cv1/act/Sigmoid_output_0']
['/model.6/m.0/cv1/act/Mul_output_0']
['/model.6/m.0/cv2/conv/Conv_output_0']
['/model.6/m.0/cv2/act/Sigmoid_output_0']
['/model.6/m.0/cv2/act/Mul_output_0']
['/model.6/m.0/Add_output_0']
['/model.6/m.1/cv1/conv/Conv_output_0']
['/model.6/m.1/cv1/act/Sigmoid_output_0']
['/model.6/m.1/cv1/act/Mul_output_0']
['/model.6/m.1/cv2/conv/Conv_output_0']
['/model.6/m.1/cv2/act/Sigmoid_output_0']
['/model.6/m.1/cv2/act/Mul_output_0']
['/model.6/m.1/Add_output_0']
['/model.6/Concat_output_0']
['/model.6/cv2/conv/Conv_output_0']
['/model.6/cv2/act/Sigmoid_output_0']
['/model.6/cv2/act/Mul_output_0']
['/model.7/conv/Conv_output_0']
['/model.7/act/Sigmoid_output_0']
['/model.7/act/Mul_output_0']
['/model.8/cv1/conv/Conv_output_0']
['/model.8/cv1/act/Sigmoid_output_0']
['/model.8/cv1/act/Mul_output_0']
['onnx::Split_138']
['/model.8/Split_output_0', '/model.8/Split_output_1']
['/model.8/m.0/cv1/conv/Conv_output_0']
['/model.8/m.0/cv1/act/Sigmoid_output_0']
['/model.8/m.0/cv1/act/Mul_output_0']
['/model.8/m.0/cv2/conv/Conv_output_0']
['/model.8/m.0/cv2/act/Sigmoid_output_0']
['/model.8/m.0/cv2/act/Mul_output_0']
['/model.8/m.0/Add_output_0']
['/model.8/Concat_output_0']
['/model.8/cv2/conv/Conv_output_0']
['/model.8/cv2/act/Sigmoid_output_0']
['/model.8/cv2/act/Mul_output_0']
['/model.9/conv/conv/Conv_output_0']
['/model.9/conv/act/Sigmoid_output_0']
['/model.9/conv/act/Mul_output_0']
['/model.9/pool/GlobalAveragePool_output_0']
['/model.9/Flatten_output_0']
['/model.9/linear/Gemm_output_0']
['output0']
name: model.0.conv.weight  dtype: 1  dim: [16, 3, 3, 3]
name: model.0.conv.bias  dtype: 1  dim: [16]
name: model.1.conv.weight  dtype: 1  dim: [32, 16, 3, 3]
name: model.1.conv.bias  dtype: 1  dim: [32]
name: model.2.cv1.conv.weight  dtype: 1  dim: [32, 32, 1, 1]
name: model.2.cv1.conv.bias  dtype: 1  dim: [32]
name: model.2.cv2.conv.weight  dtype: 1  dim: [32, 48, 1, 1]
name: model.2.cv2.conv.bias  dtype: 1  dim: [32]
name: model.2.m.0.cv1.conv.weight  dtype: 1  dim: [16, 16, 3, 3]
name: model.2.m.0.cv1.conv.bias  dtype: 1  dim: [16]
name: model.2.m.0.cv2.conv.weight  dtype: 1  dim: [16, 16, 3, 3]
name: model.2.m.0.cv2.conv.bias  dtype: 1  dim: [16]
name: model.3.conv.weight  dtype: 1  dim: [64, 32, 3, 3]
name: model.3.conv.bias  dtype: 1  dim: [64]
name: model.4.cv1.conv.weight  dtype: 1  dim: [64, 64, 1, 1]
name: model.4.cv1.conv.bias  dtype: 1  dim: [64]
name: model.4.cv2.conv.weight  dtype: 1  dim: [64, 128, 1, 1]
name: model.4.cv2.conv.bias  dtype: 1  dim: [64]
name: model.4.m.0.cv1.conv.weight  dtype: 1  dim: [32, 32, 3, 3]
name: model.4.m.0.cv1.conv.bias  dtype: 1  dim: [32]
name: model.4.m.0.cv2.conv.weight  dtype: 1  dim: [32, 32, 3, 3]
name: model.4.m.0.cv2.conv.bias  dtype: 1  dim: [32]
name: model.4.m.1.cv1.conv.weight  dtype: 1  dim: [32, 32, 3, 3]
name: model.4.m.1.cv1.conv.bias  dtype: 1  dim: [32]
name: model.4.m.1.cv2.conv.weight  dtype: 1  dim: [32, 32, 3, 3]
name: model.4.m.1.cv2.conv.bias  dtype: 1  dim: [32]
name: model.5.conv.weight  dtype: 1  dim: [128, 64, 3, 3]
name: model.5.conv.bias  dtype: 1  dim: [128]
name: model.6.cv1.conv.weight  dtype: 1  dim: [128, 128, 1, 1]
name: model.6.cv1.conv.bias  dtype: 1  dim: [128]
name: model.6.cv2.conv.weight  dtype: 1  dim: [128, 256, 1, 1]
name: model.6.cv2.conv.bias  dtype: 1  dim: [128]
name: model.6.m.0.cv1.conv.weight  dtype: 1  dim: [64, 64, 3, 3]
name: model.6.m.0.cv1.conv.bias  dtype: 1  dim: [64]
name: model.6.m.0.cv2.conv.weight  dtype: 1  dim: [64, 64, 3, 3]
name: model.6.m.0.cv2.conv.bias  dtype: 1  dim: [64]
name: model.6.m.1.cv1.conv.weight  dtype: 1  dim: [64, 64, 3, 3]
name: model.6.m.1.cv1.conv.bias  dtype: 1  dim: [64]
name: model.6.m.1.cv2.conv.weight  dtype: 1  dim: [64, 64, 3, 3]
name: model.6.m.1.cv2.conv.bias  dtype: 1  dim: [64]
name: model.7.conv.weight  dtype: 1  dim: [256, 128, 3, 3]
name: model.7.conv.bias  dtype: 1  dim: [256]
name: model.8.cv1.conv.weight  dtype: 1  dim: [256, 256, 1, 1]
name: model.8.cv1.conv.bias  dtype: 1  dim: [256]
name: model.8.cv2.conv.weight  dtype: 1  dim: [256, 384, 1, 1]
name: model.8.cv2.conv.bias  dtype: 1  dim: [256]
name: model.8.m.0.cv1.conv.weight  dtype: 1  dim: [128, 128, 3, 3]
name: model.8.m.0.cv1.conv.bias  dtype: 1  dim: [128]
name: model.8.m.0.cv2.conv.weight  dtype: 1  dim: [128, 128, 3, 3]
name: model.8.m.0.cv2.conv.bias  dtype: 1  dim: [128]
name: model.9.conv.conv.weight  dtype: 1  dim: [1280, 256, 1, 1]
name: model.9.conv.conv.bias  dtype: 1  dim: [1280]
name: model.9.linear.weight  dtype: 1  dim: [1000, 1280]
name: model.9.linear.bias  dtype: 1  dim: [1000]
[name: "output0"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 1000
      }
    }
  }
}
]
Input name: images
 Data type: FLOAT
 Shape: [1, 3, 224, 224]
Done

Clearly, for "name: model.0.conv.weight  dtype: 1  dim: [16, 3, 3, 3]", the product of its dimensions (16 × 3 × 3 × 3 = 432) is exactly the "Weight name: model.0.conv.weight, Count: 432, Type: FLOAT" reported on the TensorRT side, which confirms that the .wts weights line up with the ONNX model.
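
As a quick sanity check, the element count of any entry is just the product of its ONNX dims. A small sketch under that assumption, using the two values from the dump above:

#include <cstdint>
#include <iostream>
#include <vector>

// Element count of a tensor = product of its dimensions.
int64_t elementCount(const std::vector<int64_t>& dims) {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
}

int main() {
    // dims taken from the ONNX initializer dump above
    std::cout << elementCount({16, 3, 3, 3}) << std::endl;   // 432, the Count of model.0.conv.weight
    std::cout << elementCount({1000, 1280}) << std::endl;    // 1280000, the Count of model.9.linear.weight
    return 0;
}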
