input: dynamic input is missing dimensions in profile

input: dynamic input is missing dimensions in profile

onnx2trt代码报错:

import numpy as np
import tensorrt as trt
import os
import pycuda.driver as cuda
import argparse


def GiB(val):
    return val * 1 << 30

def ONNX_build_engine(onnx_file_path, write_engine=True):
    # :return: engine

    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    # 1、动态输入第一点必须要写的
    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    batch_size = 8  # trt推理时最大支持的batchsize
    with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, trt.OnnxParser(network,
                                                                                                             G_LOGGER) as parser:
        builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        config.max_workspace_size = GiB(2)
        config.set_flag(trt.BuilderFlag.FP16)
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        # 重点
        profile = builder.create_optimization_profile()  # 动态输入时候需要 分别为最小输入、常规输入、最大输入
        # 有几个输入就要写几个profile.set_shape 名字和转onnx的时候要对应
        # tensorrt6以后的版本是支持动态输入的,需要给每个动态输入绑定一个profile,用于指定最小值,常规值和最大值,如果超出这个范围会报异常。
        profile.set_shape("input", (1, 3, 128, 128), (4, 3, 128, 128), (16, 3, 128, 128))
        config.add_optimization_profile(profile)

        engine = builder.build_engine(network, config)
        print("Completed creating Engine")
        # 保存engine文件
        if write_engine:
            engine_file_path = 'efficientnet_b1.trt'
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
        return engine


onnx_file_path = r'skipnet_0712.onnx'
onnx_file_path = r'model2.onnx'
onnx_file_path = r'skip_simp2.onnx'
# onnx_file_path = r'mobileone_0713.onnx'
write_engine = True
engine = ONNX_build_engine(onnx_file_path, write_engine)

原错误代码:

profile.set_shape("inputs", (1, 3, 240, 240), (8, 3, 240, 240), (16, 3, 480, 480))

改之后代码:

profile.set_shape("input", (1, 3, 128, 128), (8, 3, 128, 128), (16, 3, 128, 128))

input就是onnx的输入参数,和output是输出参数

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值