因项目需求,需要对以前旧的模型在新的平台上进行适配。因为后处理修改起来较麻烦,故直接在模型中添加后处理模块。
本文主要参考官方文档:https://onnx.ai/onnx/api/
一、需求:
如下图为原始的模型输出结构:
目标模型结构:
从上面两张图,可以知道我们后面的主要工作为:
- 删除以前的输出;
- 添加几个新的算子,如Transpose,Shape,Gather,Unsqueeze,Concat等;
- 重新定义一个输出;
二、 操作步骤
1. 加载模型,输出模型信息
import onnx
onnx_path = 'path/to/model.onnx'
onnx_model = onnx.load(onnx_path) #加载模型
graph = onnx_model.graph # 获取graph
# 打印模型结构
nodes = graph.node
for i,node in enumrate(nodes):
print(i,node)
2. 更改输出
# 删除所有输出
out = onnx_model.graph.output
for i in range(len(out)):
del out[0]
# 创建新的输出,并添加进模型的输出中
out_cls = helper.make_tensor_value_info("cls", TensorProto.FLOAT, [1, 16800, 2])
onnx_model.graph.output.append(out_cls)
3. 创建节点
创建节点如下所示,其他几种节点也是相同操作。
# 创建一个Transpose节点,并指定
node_def = helper.make_node(
"Transpose", # name
["input"], # 指定inputs
["output"], # 指定outputs
perm = [0,2,3,1], # attributes
)
# update: 添加几个新的node的创建方式
# constant node
coffi0_shape_np = np.array([1, 32, 3840]).astype(np.int64)
coffi0_shape = helper.make_node(
"Constant",
inputs=[],
outputs=["coffi0_shape"],
name="/model.22/cv4.0/cv4.0.2/Constant",
value=onnx.numpy_helper.from_array(coffi0_shape_np),
)
graph.node.insert(182, coffi0_shape)
# reshape
coffi0_reshape = helper.make_node(
"Reshape",
["/model.22/cv4.0/cv4.0.2/Conv_output_0", "coffi0_shape"],
["coffi0_reshape"],
name="/model.22/cv4.0/cv4.0.2/Reshape",
)
graph.node.insert(183, coffi0_reshape)
# transpose
coffi0_transpose = onnx.helper.make_node(
"Transpose",
inputs=["coffi0_reshape"],
outputs=["coffi0_transpose"],
perm=[0, 2, 1],
name="/model.22/cv4.0/cv4.0.2/transpose",
)
graph.node.insert(184, coffi0_transpose)
4. 检测模型
调用onnx.checker检测模型结构是否正确
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print(f"The model is invalid: {e}")
else:
print("The model is valid!")
5. 保存模型
onnx.save(onnx_model, "path/to/the/optimized/model.onnx")