ONNX 新增节点
1. 首先查阅 ONNX 算子文档,了解各算子的输入要求及作用。
2. 利用下面的脚本添加节点(以在 YOLOX 模型的输入之后添加 Resize 节点为例):
import onnx
import numpy as np
from onnx import AttributeProto, TensorProto, GraphProto
# Location of the exported model to patch.  Raw string so the backslashes of
# the Windows path are taken literally instead of being parsed as (invalid)
# escape sequences such as "\T" or "\P".
onnx_path = r"D:\TwoModel2One\yoloAndeff\PlusTwoOnnx\GuDingChiCun\yolox_s.onnx"

# Load the model; `graph` and `node` refer to the model's own protobuf
# fields, so every edit made through them changes `onnx_model` in place.
onnx_model = onnx.load(onnx_path)
graph = onnx_model.graph
node = graph.node

# Show the graph interface before it is modified.
print("graph_input:", graph.input)
print("graph_output:", graph.output)
# New shape for each graph input that should be rewritten, keyed by input
# name.  An int entry becomes a fixed dimension (dim_value), a str entry a
# symbolic dimension (dim_param).
new_shapes = {
    "images": ["batch", 3, 540, 960],
}

for _input in graph.input:
    print("_input:", _input.name)

    new_shape = new_shapes.get(_input.name)
    if new_shape is None:
        # Graph inputs without an entry in `new_shapes` keep their declared
        # shape instead of aborting the script with a KeyError.
        continue

    tensor_shape_proto = _input.type.tensor_type.shape

    # Delete the old dimensions before writing the new ones.
    del tensor_shape_proto.dim[:]

    for d in new_shape:
        dim = tensor_shape_proto.dim.add()
        if d is None:
            # NOTE(review): None is written as dim_value -1; confirm that the
            # downstream tools accept a negative dim_value for "unknown".
            d = -1
        if isinstance(d, int):
            dim.dim_value = d
        elif isinstance(d, str):
            dim.dim_param = d
        else:
            raise ValueError(f"invalid shape: {new_shape}")

print("updated graph_input:", onnx_model.graph.input)
# --- 1. Constant inputs of the new Resize node ------------------------------
# One scale factor per axis of the "images" input.  The first two axes are
# left at 1.0; the last two are scaled so that 540 * 1.1852 and 960 * 0.6667
# both land just above 640 (assumes NCHW layout and a 640x640 network input
# -- TODO confirm against the exported model).
scales_tensor = onnx.helper.make_tensor(
    name="scales",
    data_type=onnx.TensorProto.FLOAT,
    dims=[4],
    vals=np.array([1.0, 1.0, 1.1852, 0.6667], dtype=np.float32),
)

# Explicit empty region-of-interest tensor.  An empty-string input name for
# `roi` is only accepted from opset 13 on; an empty tensor is also valid for
# opset 11/12 models, whose Resize declares `roi` as a required input.
roi_tensor = onnx.helper.make_tensor(
    name="resize_roi",
    data_type=onnx.TensorProto.FLOAT,
    dims=[0],
    vals=[],
)
graph.initializer.extend([scales_tensor, roi_tensor])

# --- 2. Insert the Resize node at the front of the graph ---------------------
new_node = onnx.helper.make_node(
    "Resize",
    inputs=["images", "resize_roi", "scales"],
    outputs=["resized_images"],
    mode="linear",
)
graph.node.insert(0, new_node)

# --- 3. Re-route the consumers of the raw input to the resized tensor --------
# Names of the Slice nodes whose first input is switched over.  These names
# are assigned by the exporter, so they are specific to this model file.
slice_names_to_rewire = {"Slice_4", "Slice_14", "Slice_24", "Slice_34"}
rewired_names = set()
for i, candidate in enumerate(graph.node):
    if candidate.op_type == "Slice" and candidate.name in slice_names_to_rewire:
        print(i)
        candidate.input[0] = "resized_images"
        rewired_names.add(candidate.name)

# Fail early instead of saving a graph in which only some consumers read the
# resized tensor.
missing_names = slice_names_to_rewire - rewired_names
if missing_names:
    raise RuntimeError(
        f"Slice nodes not found in graph: {sorted(missing_names)}"
    )

# --- 4. Rebuild, validate and save the model ---------------------------------
# The graph is rebuilt from nodes, inputs, outputs and initializers only, so
# the previously stored value_info is not carried over and the intermediate
# shapes come from the shape inference pass below.
graph = onnx.helper.make_graph(
    graph.node,
    graph.name,
    graph.input,
    graph.output,
    graph.initializer,
)

# Keep the opset and IR version of the loaded model.  Without these arguments
# make_model() stamps the newest versions known to the installed onnx package,
# which silently changes how the existing operators are interpreted.
info_model = onnx.helper.make_model(
    graph,
    opset_imports=onnx_model.opset_import,
    ir_version=onnx_model.ir_version,
)

onnx_model = onnx.shape_inference.infer_shapes(info_model)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "resize_yolox.onnx")