一个训练好的 onnx 中有效信息包含 graph 中的 node(位置关系),initializer(权重),input、output(它们不是node,是value info)。
假设有 3个 onnx,是具有相同结构、不同权重的模型(例如使用不同数据训练)。
原模型的输入名为 ‘img’,输出名为 ‘fea’。
新模型的输入和原模型一样,输出需要把原输出相加。
# -*- coding:utf-8 -*-
import onnx
import copy
list_modelname = ['./model1.onnx', './model2.onnx', './model3.onnx']
#模型中节点和输入输出名字都一样,合并后要加个前缀
list_prefix = ['m1_', 'm2_', 'm3_']
model_num = len(list_modelname)
input_name = 'img'
output_name = 'fea'
list_model = []
for i in range(model_num):
model = onnx.load(list_modelname[i])
list_model.append(model)
#0、构造新模型
modelX = onnx.ModelProto(ir_version=list_model[0].ir_version,
producer_name=list_model[0].producer_name,
producer_version=list_model[0].producer_version,
opset_import=list_model[0].opset_import)
#1、添加 input
model1_input_tensor_type = list_model[0].graph.input[0].type.tensor_type
input_elem_type = model1_input_tensor_type.elem_type
input_shape = []
for s in model1_input_tensor_type.shape.dim:
if (s.dim_value > 0):
input_shape.append(s.dim_value)
else:
input_shape.append(s.dim_param)
modelX_input = onnx.helper.make_tensor_value_info(
input_name,
input_elem_type,
input_shape
)
modelX.graph.input.append(modelX_input)
#2、添加 output
model1_output_tensor_type = list_model[0].graph.output[0].type.tensor_type
output_elem_type = model1_output_tensor_type.elem_type
output_shape = []
for s in model1_output_tensor_type.shape.dim:
if (s.dim_value > 0):
output_shape.append(s.dim_value)
else:
output_shape.append(s.dim_param)
modelX_output = onnx.helper.make_tensor_value_info(
output_name,
output_elem_type,
output_shape
)
modelX.graph.output.append(modelX_output)
#3、添加输出前的 add 节点
node = onnx.helper.make_node(
'Add',
name='add_fea',
inputs=['m1_fea', 'm2_fea', 'm3_fea'],
outputs=[output_name],
)
modelX.graph.node.append(node)
#4、添加训练好的模型中的节点和权重
for idx in range(model_num):
model = list_model[idx]
for node in model.graph.node:
for i in range(len(node.input)):
if (node.input[i] != input_name):
node.input[i] = list_prefix[idx] + node.input[i]
for i in range(len(node.output)):
node.output[i] = list_prefix[idx] + node.output[i]
node.name = list_prefix[idx] + node.name
modelX.graph.node.append(node)
for weight in model.graph.initializer:
weight.name = list_prefix[idx] + weight.name
modelX.graph.initializer.append(weight)
#5、保存新模型
onnx.save(modelX, './modelX.onnx')
参考: