修改onnx模型中间节点命名(包含输入、输出重命名)

来源:Paddle2ONNX

Paddle2ONNX/tools/onnx/README.md at develop · PaddlePaddle/Paddle2ONNX · GitHub

依赖:import onnx

python rename_onnx_model.py --model model.onnx --origin_names x y z --new_names x1 y1 z1 --save_file new_model.onnx

其中 origin_names 和 new_names,前者表示原模型中各个命名(可指定多个),后者表示新命名,两个参数指定的命名个数需要相同 

 

import argparse
import sys


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--origin_names',
        required=True,
        nargs='+',
        help='The original name you want to modify.')
    parser.add_argument(
        '--new_names',
        required=True,
        nargs='+',
        help='The new name you want change to, the number of new_names should be same with the number of origin_names'
    )
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    import onnx
    model = onnx.load(args.model)
    output_tensor_names = set()
    for ipt in model.graph.input:
        output_tensor_names.add(ipt.name)
    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    for origin_name in args.origin_names:
        if origin_name not in output_tensor_names:
            print("[ERROR] Cannot find tensor name '{}' in onnx model graph.".
                  format(origin_name))
            sys.exit(-1)
    if len(set(args.origin_names)) < len(args.origin_names):
        print(
            "[ERROR] There's dumplicate name in --origin_names, which is not allowed."
        )
        sys.exit(-1)
    if len(args.new_names) != len(args.origin_names):
        print(
            "[ERROR] Number of --new_names must be same with the number of --origin_names."
        )
        sys.exit(-1)
    if len(set(args.new_names)) < len(args.new_names):
        print(
            "[ERROR] There's dumplicate name in --new_names, which is not allowed."
        )
        sys.exit(-1)
    for new_name in args.new_names:
        if new_name in output_tensor_names:
            print(
                "[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
            )
            sys.exit(-1)

    for i, ipt in enumerate(model.graph.input):
        if ipt.name in args.origin_names:
            idx = args.origin_names.index(ipt.name)
            model.graph.input[i].name = args.new_names[idx]

    for i, node in enumerate(model.graph.node):
        for j, ipt in enumerate(node.input):
            if ipt in args.origin_names:
                idx = args.origin_names.index(ipt)
                model.graph.node[i].input[j] = args.new_names[idx]
        for j, out in enumerate(node.output):
            if out in args.origin_names:
                idx = args.origin_names.index(out)
                model.graph.node[i].output[j] = args.new_names[idx]

    for i, out in enumerate(model.graph.output):
        if out.name in args.origin_names:
            idx = args.origin_names.index(out.name)
            model.graph.output[i].name = args.new_names[idx]

    onnx.checker.check_model(model)
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
修改前

修改后
可以使用 ONNX Runtime 的 API 来删除 ONNX 模型中的多余节点,并将其导出为 TensorRT 引擎。以下是一些步骤: 1. 加载 ONNX 模型 首先,使用 ONNX Runtime 的 Python API 加载 ONNX 模型。可以使用以下代码: ```python import onnx import onnxruntime as ort # Load the ONNX model onnx_model = onnx.load("model.onnx") ``` 2. 删除多余节点 使用 ONNX Runtime 的 API,可以轻松删除 ONNX 模型中的多余节点。可以使用以下代码: ```python # Create a new ONNX model without the unnecessary nodes inputs = ["input_0"] outputs = ["output_0"] new_model = ort.quantization.quantize_dynamic(onnx_model, inputs=inputs, outputs=outputs) ``` 在这个例子中,我们使用 ONNX Runtime 的 `quantize_dynamic` API 来删除模型中的多余节点。我们还指定了输入输出节点的名称。 3. 导出 TensorRT 引擎 使用 TensorRT 的 ONNX Parser,可以将 ONNX 模型解析为 TensorRT 的网络表示形式。可以使用以下代码将新的 ONNX 模型导出为 TensorRT 引擎: ```python import tensorrt as trt # Create a TensorRT builder builder = trt.Builder(TRT_LOGGER) # Create a TensorRT network from the ONNX model network = builder.create_network() parser = trt.OnnxParser(network, TRT_LOGGER) parser.parse_from_string(new_model.SerializeToString()) # Build an engine from the TensorRT network engine = builder.build_cuda_engine(network) ``` 在这个例子中,我们使用 TensorRT 的 Python API 创建一个 TensorRT builder 和一个 TensorRT network。然后,使用 TensorRT 的 ONNX Parser 将新的 ONNX 模型解析为 TensorRT 的网络表示形式,并将其添加到 TensorRT network 中。最后,使用 TensorRT builder 构建一个 TensorRT 引擎。 注意,这个例子中使用的是 `parse_from_string` 方法来解析 ONNX 模型。这是因为我们已经使用 ONNX Runtime 对模型进行了修改。如果您没有修改模型,则可以使用 `parse` 方法来解析原始 ONNX 模型。 4. 运行 TensorRT 引擎 构建完 TensorRT 引擎后,可以使用与前面例子中相同的代码来运行 TensorRT 推理。 ```python import pycuda.driver as cuda import pycuda.autoinit import numpy as np # Load the engine with open("engine.plan", "rb") as f: engine_data = f.read() engine = runtime.deserialize_cuda_engine(engine_data) # Allocate input and output buffers on the GPU input_bindings = [] output_bindings = [] stream = cuda.Stream() for binding in engine: size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size dtype = trt.nptype(engine.get_binding_dtype(binding)) if engine.binding_is_input(binding): input_bindings.append(cuda.mem_alloc(size * dtype.itemsize)) else: output_bindings.append(cuda.mem_alloc(size * dtype.itemsize)) # Load input data to the GPU input buffer input_data = np.random.randn(batch_size, input_size) cuda.memcpy_htod(input_bindings[0], input_data.flatten().astype(np.float32)) # Run inference context = engine.create_execution_context() context.execute_async_v2(bindings=input_bindings + output_bindings, stream_handle=stream.handle) cuda.streams.synchronize() # Get the output data from the GPU output buffer output_data = np.empty((batch_size, output_size), dtype=np.float32) cuda.memcpy_dtoh(output_data.flatten(), output_bindings[0]) ``` 在这个过程中,首先使用 TensorRT 的 Python API 加载 TensorRT 引擎。然后,使用 PyCUDA 分配输入输出缓冲区,并将输入数据从主机(CPU)传输到设备(GPU)。接下来,使用 TensorRT 的 Python API 创建一个 TensorRT 执行上下文,并在 GPU 上异步执行 TensorRT 推理。最后,使用 PyCUDA 将输出数据从设备(GPU)传输到主机(CPU)。 这就是如何使用 ONNX Runtime API 删除 ONNX 模型中的多余节点,并将其导出为 TensorRT 引擎。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值