onnx新增节点

onnx新增节点

1 首先查找onnx算子,各算子的要求以及作用
2 利用下列算法进行添加(以在yolox中添加resize节点为例):

import onnx
import numpy as np
from onnx import AttributeProto, TensorProto, GraphProto


onnx_path = "D:\TwoModel2One\yoloAndeff\PlusTwoOnnx\GuDingChiCun\yolox_s.onnx"
onnx_model = onnx.load(onnx_path)
graph = onnx_model.graph
node = onnx_model.graph.node

print("graph_input:", graph.input)
print("graph_output:", graph.output)

new_shapes = {
    "images": ["batch", 3, 540, 960],
}

for _input in graph.input:
    print("_input:", _input.name)
    tensor_shape_proto = _input.type.tensor_type.shape

    new_shape = new_shapes[_input.name]
    # delete old shape
    elem_num = len(tensor_shape_proto.dim)
    for i in reversed(range(elem_num)):
        del tensor_shape_proto.dim[i]

    for i, d in enumerate(new_shape):
        dim = tensor_shape_proto.dim.add()
        if d is None:
            d = -1
        if isinstance(d, int):
            dim.dim_value = d
        elif isinstance(d, str):
            dim.dim_param = d
        else:
            raise ValueError(f"invalid shape: {new_shape}")

print("updated graph_input:", onnx_model.graph.input)

resize_node = onnx.helper.make_tensor(name='scales',
                                      data_type=onnx.TensorProto.FLOAT,
                                      dims=[4],
                                      vals = np.array([1.0, 1.0, 1.1852,0.6667], dtype=np.float32)
                                        )

graph.initializer.append(resize_node)

new_node = onnx.helper.make_node(
    "Resize",
    inputs=["images", "", "scales"],
    outputs=["resized_images"],
    mode="linear",
)

nodes = [new_node]
graph.node.insert(0, new_node)
#
for i in range(len(node)):
    if node[i].op_type == "Slice" and node[i].name == "Slice_4" :
        print(i)
        node[i].input[0] = "resized_images"
    if node[i].op_type == "Slice" and node[i].name == "Slice_14":
        print(i)
        node[i].input[0] = "resized_images"
    if node[i].op_type == "Slice" and node[i].name == "Slice_24":
        print(i)
        node[i].input[0] = "resized_images"
    if node[i].op_type == "Slice" and node[i].name == "Slice_34":
        print(i)
        node[i].input[0] = "resized_images"


graph = onnx.helper.make_graph(graph.node, graph.name, graph.input, graph.output, graph.initializer)
info_model = onnx.helper.make_model(graph)
onnx_model = onnx.shape_inference.infer_shapes(info_model)

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "resize_yolox.onnx")


  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值