合并多个 onnx

一个训练好的 onnx 中有效信息包含 graph 中的 node(位置关系),initializer(权重),input、output(它们不是node,是value info)。

假设有 3个 onnx,是具有相同结构、不同权重的模型(例如使用不同数据训练)。

原模型的输入名为 ‘img’,输出名为 ‘fea’。

新模型的输入和原模型一样,输出需要把原输出相加。

# -*- coding:utf-8 -*-
import onnx
import copy

list_modelname = ['./model1.onnx', './model2.onnx', './model3.onnx']
#模型中节点和输入输出名字都一样,合并后要加个前缀
list_prefix = ['m1_', 'm2_', 'm3_']

model_num = len(list_modelname)
input_name = 'img'
output_name = 'fea'

list_model = []
for i in range(model_num):
    model = onnx.load(list_modelname[i])
    list_model.append(model)

#0、构造新模型
modelX = onnx.ModelProto(ir_version=list_model[0].ir_version,
                         producer_name=list_model[0].producer_name,
                         producer_version=list_model[0].producer_version,
                         opset_import=list_model[0].opset_import)

#1、添加 input
model1_input_tensor_type = list_model[0].graph.input[0].type.tensor_type
input_elem_type = model1_input_tensor_type.elem_type
input_shape = []
for s in model1_input_tensor_type.shape.dim:
    if (s.dim_value > 0):
        input_shape.append(s.dim_value)
    else:
        input_shape.append(s.dim_param)   

modelX_input = onnx.helper.make_tensor_value_info(
                                input_name,
                                input_elem_type,
                                input_shape
                            )
modelX.graph.input.append(modelX_input)

#2、添加 output
model1_output_tensor_type = list_model[0].graph.output[0].type.tensor_type
output_elem_type = model1_output_tensor_type.elem_type
output_shape = []
for s in model1_output_tensor_type.shape.dim:
    if (s.dim_value > 0):
        output_shape.append(s.dim_value)
    else:        
        output_shape.append(s.dim_param)

modelX_output = onnx.helper.make_tensor_value_info(
                                output_name,
                                output_elem_type,
                                output_shape
                            )
modelX.graph.output.append(modelX_output)

#3、添加输出前的 add 节点
node = onnx.helper.make_node(
    'Add',
    name='add_fea',
    inputs=['m1_fea', 'm2_fea', 'm3_fea'],
    outputs=[output_name],
)
modelX.graph.node.append(node)

#4、添加训练好的模型中的节点和权重
for idx in range(model_num):
    model = list_model[idx]
    for node in model.graph.node:        
        for i in range(len(node.input)):
            if (node.input[i] != input_name):
                node.input[i] = list_prefix[idx] + node.input[i]

        for i in range(len(node.output)):
            node.output[i] = list_prefix[idx] + node.output[i]
            
        node.name = list_prefix[idx] + node.name        
        modelX.graph.node.append(node)

    for weight in model.graph.initializer:
        weight.name = list_prefix[idx] + weight.name
        modelX.graph.initializer.append(weight)

#5、保存新模型
onnx.save(modelX, './modelX.onnx')

参考:

https://hexdocs.pm/onnxs/Onnx.ModelProto.html

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
在 C++ 中使用 ONNX Runtime 同时处理多个 batch size 的数据,你需要使用 ONNX Runtime 的 C++ API。具体步骤如下: 1. 加载模型。使用 `Ort::Env` 类创建一个运行环境,然后使用 `Ort::SessionOptions` 类设置会话选项,最后使用 `Ort::Session` 类加载模型。 ```C++ Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); Ort::SessionOptions session_options; Ort::Session session(env, model_path.c_str(), session_options); ``` 2. 获取输入和输出信息。使用 `Ort::Session` 的 `GetInputTypeInfo` 和 `GetOutputTypeInfo` 方法获取输入和输出的类型信息。 ```C++ auto input_info = session.GetInputTypeInfo(); auto output_info = session.GetOutputTypeInfo(); ``` 3. 准备输入数据。将多个 batch size 的数据拼接在一起,然后将拼接后的数据转换成 ONNX Runtime 所需的格式。 ```C++ // 将多个 batch size 的数据拼接在一起 std::vector<float> input_data; for (size_t i = 0; i < batch_sizes.size(); i++) { input_data.insert(input_data.end(), inputs[i].begin(), inputs[i].end()); } // 将拼接后的数据转换成 ONNX Runtime 所需的格式 std::vector<int64_t> input_dims = {batch_sizes.size(), input_size}; auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data.data(), input_data.size(), input_dims.data(), input_dims.size()); ``` 4. 执行推理。使用 `Ort::Session` 的 `Run` 方法执行推理,获取输出结果。 ```C++ // 执行推理 auto output_tensors = session.Run(run_options, input_names.data(), &input_tensor, input_names.size(), output_names.data(), output_names.size()); // 获取输出结果 for (size_t i = 0; i < output_tensors.size(); i++) { auto output_tensor = output_tensors[i].Get<Tensor>(); auto output_dims = output_tensor.Shape().GetDims(); auto output_size = output_tensor.Shape().Size(); std::vector<float> output_data(output_size); output_tensor.CopyTo(output_data.data(), output_size * sizeof(float)); // 将输出结果按照 batch size 分组 for (size_t j = 0; j < batch_sizes.size(); j++) { auto start_index = j * output_size / batch_sizes.size(); auto end_index = (j + 1) * output_size / batch_sizes.size(); auto output = std::vector<float>(output_data.begin() + start_index, output_data.begin() + end_index); // 处理输出结果 // ... } } ``` 需要注意的是,在拼接输入数据时,不同的 batch size 的数据要保证维度相同,即在各个维度上的大小应该一致。在处理输出结果时,需要将输出结果按照 batch size 分组,然后进行相应的处理。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值