TensorRT 8: ONNX to TensorRT

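This code tries to load a serialized TensorRT engine and builds a new one if none exists. It first parses the ONNX model, then builds the network with the maximum workspace size, batch size and FP16 mode set. A `cuMemcpyHtoDAsync` error usually means the network input dimensions do not match. Make sure the input dimensions match the TensorRT model, and regenerate the engine after fixing the problem.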
```python
import os
import sys

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# ONNX models must be parsed into an explicit-batch network.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""

    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            # TensorRT 8: workspace size and FP16 are set on the builder config.
            # (max_workspace_size is deprecated from 8.4 in favour of set_memory_pool_limit.)
            config.max_workspace_size = 1 << 32  # 4GB
            # No effect for explicit-batch networks; the batch size comes from the input shape below.
            builder.max_batch_size = 1
            config.set_flag(trt.BuilderFlag.FP16)
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print(
                    "ONNX file {} not found, please run yolov5_to_onnx.py first to generate it.".format(onnx_file_path)
                )
                sys.exit(1)
            print("Loading ONNX file from path {}...".format(onnx_file_path))
            with open(onnx_file_path, "rb") as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            # Fix the input to NCHW [1, 3, 640, 640]; inference inputs must match this exactly.
            network.get_input(0).shape = [1, 3, 640, 640]

            print("Completed parsing of ONNX file")
            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
            serialized_engine = builder.build_serialized_network(network, config)
            if serialized_engine is None:
                print("ERROR: Failed to build the engine.")
                return None
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            print("Completed creating Engine")
            if engine_file_path:  # open("") would fail, so only save when a path was given
                with open(engine_file_path, "wb") as f:
                    f.write(serialized_engine)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
```
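The build code fixes the engine input to [1, 3, 640, 640], so every inference input has to be a float32 NCHW array of exactly that shape. As a minimal sketch, assuming the image has already been resized or letterboxed to 640×640 in RGB order and that the model expects pixel values scaled to [0, 1] (as YOLOv5 exports usually do), converting an HWC uint8 image could look like this. `to_engine_input` is a hypothetical helper, not part of TensorRT:

```python
import numpy as np


def to_engine_input(image_hwc, size=640):
    """Turn a resized HWC uint8 RGB image into a (1, 3, size, size) float32 array.

    Hypothetical helper: assumes resizing/letterboxing and BGR->RGB conversion
    have already been done, and that the model expects values in [0, 1].
    """
    if image_hwc.shape != (size, size, 3):
        raise ValueError(f"expected a ({size}, {size}, 3) image, got {image_hwc.shape}")
    x = image_hwc.astype(np.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]
    x = np.transpose(x, (2, 0, 1))            # HWC -> CHW (a non-contiguous view)
    x = np.expand_dims(x, axis=0)             # CHW -> NCHW with batch size 1
    return np.ascontiguousarray(x)            # copy into contiguous memory for the H2D copy
```

The final `np.ascontiguousarray` matters because `np.transpose` only returns a strided view, while the host-to-device copy expects one contiguous buffer.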

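Compared with earlier versions, TensorRT 8 mainly differs in how settings such as max_batch_size and fp16_mode are made: FP16 is enabled on the builder config with `config.set_flag(trt.BuilderFlag.FP16)` instead of `builder.fp16_mode = True`, and `builder.max_batch_size` has no effect on the explicit-batch networks that ONNX models require. The code is shown above. Reference: https://blog.csdn.net/weixin_42492254/article/details/125319112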
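Because `config.max_workspace_size` was itself deprecated in TensorRT 8.4 in favour of memory pools, code meant to run on several 8.x releases can pick whichever API the installed version provides. This is a minimal sketch, assuming you pass in the imported `tensorrt` module (a parameter only so the helper can be exercised with stand-in objects on a machine without a GPU). `configure_build` is a hypothetical name, and unlike the code above it only enables FP16 when `builder.platform_has_fast_fp16` reports native FP16 support:

```python
def configure_build(builder, config, trt, workspace_bytes=1 << 32):
    """Set workspace size and FP16 on a TensorRT 8.x builder config.

    `trt` is the imported tensorrt module; it is a parameter only so that the
    helper can be tested with stand-in objects. Returns True if FP16 was enabled.
    """
    # TensorRT >= 8.4 exposes memory pools; earlier 8.x releases use max_workspace_size.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes
    # FP16 is a builder-config flag in TensorRT 8 (TensorRT 7 code used builder.fp16_mode = True).
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        return True
    return False
```

Inside `build_engine`, a call such as `configure_build(builder, config, trt)` would take the place of the three settings lines.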
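If you get this error:

pycuda._driver.LogicError: cuMemcpyHtoDAsync failed: invalid argument

it is almost always a problem with the network input dimensions. The input dtype can also cause it, but in my case it was the dimensions. The host-to-device copy fails when the host array is larger than the device buffer allocated from the engine's input shape; a float64 array, for example, is twice the size of the float32 buffer. The input must match the TensorRT engine's input exactly. If it does not, check your input preprocessing, or check whether the dimensions were specified incorrectly in the ONNX->TensorRT conversion itself. After correcting the dimensions, be sure to regenerate the TensorRT engine! Since get_engine loads an existing engine file instead of building a new one, delete the old engine file first.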
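A cheap guard against this error is to validate the host array before the copy. The sketch below assumes NumPy inputs and that the expected shape and dtype are read from the engine; in TensorRT 8 these are available through `engine.get_binding_shape(0)` and `trt.nptype(engine.get_binding_dtype(0))`. `prepare_input` is a hypothetical helper, not part of TensorRT or PyCUDA:

```python
import numpy as np


def prepare_input(x, expected_shape, expected_dtype=np.float32):
    """Check a host array against the engine's input binding before the H2D copy."""
    x = np.asarray(x)
    expected_shape = tuple(expected_shape)
    if x.shape != expected_shape:
        raise ValueError(
            f"input shape {x.shape} != engine input shape {expected_shape}; "
            "fix preprocessing, or rebuild the engine with the intended input shape"
        )
    # Cast (e.g. float64 -> float32) and make the memory contiguous so the host
    # buffer has exactly the byte size of the device buffer for this binding.
    return np.ascontiguousarray(x, dtype=expected_dtype)
```

The returned array can then go to `cuda.memcpy_htod_async(d_input, host_input, stream)` as usual; a shape mismatch now surfaces as a readable `ValueError` instead of a driver error.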
