修改onnx模型node

一、先列出有价值的参考链接/可学习的链接

  • onnx_python_examples 范例(已fork):链接
  • 一个可视化操作onnx 的git【一些op没实现】:链接
  • 一个比较好的知乎学习:链接
  • 比较好的一个csdn上对修改onnx的总结:链接
  • 一个知乎的简单总结:链接
  • 另一个csdn总结:链接
  • 又一个github修改shape的code:链接

二、代码整理

2.1 修改输入输出节点名称以及模型名称

import onnx

onnx_model = onnx.load('/root/Desktop/qxc_0613/mt_hmr_hardnet_rename.onnx')
export_model = '/root/Desktop/qxc_0613/mt_hmr_hardnet_rename_final.onnx'

# Each entry has the form "new_name:old_name", e.g. 'POSE:648' renames the
# tensor '648' to 'POSE'.
endpoint_names = ['input:input.1', 'POSE:648', 'SHAPE:570', 'CAM:572']

model_name = 'human_params_estimation'


def build_rename_map(specs):
    """Parse "new_name:old_name" specs into an {old_name: new_name} dict.

    Only the first ':' separates the two parts, so an old name that itself
    contains ':' (e.g. TensorFlow-style 'x:0') is kept intact.

    Raises:
        ValueError: if a spec has no ':' or an empty part.  An empty old name
            would otherwise match the empty string ONNX uses for omitted
            optional node inputs.
    """
    rename_map = {}
    for spec in specs:
        new_name, sep, old_name = spec.partition(':')
        if not sep or not new_name or not old_name:
            raise ValueError("endpoint spec must look like 'new:old', got %r" % spec)
        rename_map[old_name] = new_name
    return rename_map


# old tensor name -> new tensor name
rename_map = build_rename_map(endpoint_names)

# Rename every node input/output that refers to an endpoint tensor.
# Matching is an exact dict lookup, never a substring `in` test: omitted
# optional inputs are stored as '' and a short name such as '57' must not
# match '572'.  One lookup per name also prevents chained renames
# ('a:b' together with 'c:a' must not turn 'b' into 'c').
for node in onnx_model.graph.node:
    for names in (node.input, node.output):
        for j, name in enumerate(names):
            if name in rename_map:
                print('-'*60)
                print("node name: ", node.name)
                print("tensor: ", name, "->", rename_map[name])
                names[j] = rename_map[name]

# Rename the graph-level declarations of the same tensors.  value_info is
# included so shape annotations keep referring to the renamed tensors.
for value_info in (*onnx_model.graph.input,
                   *onnx_model.graph.output,
                   *onnx_model.graph.value_info):
    if value_info.name in rename_map:
        print('-'*60)
        print(value_info)
        value_info.name = rename_map[value_info.name]

# Set the graph (model) name
print("before modify onnx_model.graph.name is: ", onnx_model.graph.name)
onnx_model.graph.name = model_name
print("after modify onnx_model.graph.name is: ", onnx_model.graph.name)


# Save the model
onnx.save(onnx_model, export_model)


2.2 修改模型graph的initializer以及node属性(包括check模型)

import onnx
import numpy as np
import torch

def create_initializer_tensor(
        name: str,
        tensor_array: np.ndarray,
        data_type: int = onnx.TensorProto.FLOAT
) -> onnx.TensorProto:
    """Build a TensorProto suitable for appending to ``graph.initializer``.

    Args:
        name: Tensor name; it must match the name the consuming nodes use.
        tensor_array: Values to store.  Anything ``np.asarray`` accepts
            (ndarray, nested list, scalar) is allowed; values are flattened
            in row-major order and ``dims`` is taken from its shape.
        data_type: An ``onnx.TensorProto`` data-type enum value such as
            ``onnx.TensorProto.INT64``.  This is a plain int, not a
            TensorProto instance.

    Returns:
        The new TensorProto.  It is not attached to any graph.
    """
    tensor_array = np.asarray(tensor_array)
    return onnx.helper.make_tensor(
        name=name,
        data_type=data_type,
        dims=tensor_array.shape,
        vals=tensor_array.flatten().tolist())

def _make_grid(nx, ny):
    """Return a (1, 1, ny*nx, 2) int64 array of (x, y) grid-cell coordinates.

    Row k holds (k % nx, k // nx), so x varies fastest.  This is the same
    layout as ``torch.stack((xv, yv), 2).view(1, 1, ny*nx, 2)`` with
    ``yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])``.
    """
    yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
    return np.stack((xv, yv), 2).reshape(1, 1, ny * nx, 2).astype(np.int64)


def replace_initializer_node(graph, *, replacements=None, verbose=True):
    """Overwrite the values of selected constant initializers, in place.

    ONNX constants may be stored either as ``graph.initializer`` entries or
    as ``Constant`` nodes.  Only the former are handled here.

    Args:
        graph: GraphProto to edit.  It is modified in place.
        replacements: Mapping of initializer name -> new values (anything
            ``np.asarray`` accepts).  Defaults to the hard-coded table for
            this particular model (tensor names '276' ... '424').
        verbose: If True, first print every initializer so the names to
            patch can be looked up.

    Returns:
        The same *graph* object.

    Each replacement keeps the original initializer's name, position and
    data type; only dims and values change.  Keeping the data type matters
    for operands such as Add's B, which must have the same type as Add's
    other input.  Names that are not found are reported and skipped.
    """
    if replacements is None:
        replacements = {
            # operator: shape -- new target shapes
            '276': np.array([1, 1, 85, 6400], dtype=np.int64),
            '340': np.array([1, 1, 85, 1600], dtype=np.int64),
            '404': np.array([1, 1, 85, 400], dtype=np.int64),
            # operator: Add, input B -- (x, y) grids of 80x80, 40x40, 20x20
            '284': _make_grid(80, 80),
            '348': _make_grid(40, 40),
            '412': _make_grid(20, 20),
        }
        # operator: Slice -- axes set to [3]
        for name in ('279', '288', '296',
                     '343', '352', '360',
                     '407', '416', '424'):
            replacements[name] = np.array([3], dtype=np.int64)

    if verbose:
        for initid, initializer in enumerate(graph.initializer):
            print("######%s######" % initid)
            print(initializer)
            print('--------next initializer:')

    # A single pass over the initializers.  pop() ensures each name is
    # replaced at most once, i.e. at its first occurrence.
    pending = dict(replacements)
    for initializer in graph.initializer:
        if initializer.name not in pending:
            continue
        values = pending.pop(initializer.name)
        new_tensor = create_initializer_tensor(
            initializer.name, np.asarray(values),
            data_type=initializer.data_type)
        # CopyFrom rewrites the entry where it is.  This avoids deleting
        # from graph.initializer and re-appending at the end.
        initializer.CopyFrom(new_tensor)

    for name in pending:
        print("warning: initializer %r not found, left unchanged" % name)
    return graph

def modify_node_attribute(graph, perm_by_node=None):
    """Rewrite the ``perm`` attribute of selected Transpose nodes, in place.

    Args:
        graph: GraphProto whose nodes are edited in place.
        perm_by_node: Mapping of node name -> new permutation.  Defaults to
            setting perm to [0, 1, 3, 2] on Transpose_121, Transpose_177 and
            Transpose_233.

    Returns:
        The same *graph* object.

    Nodes that have no ``perm`` attribute are left untouched.
    """
    if perm_by_node is None:
        perm_by_node = {
            "Transpose_121": [0, 1, 3, 2],
            "Transpose_177": [0, 1, 3, 2],
            "Transpose_233": [0, 1, 3, 2],
        }

    for node in graph.node:
        perm = perm_by_node.get(node.name)
        if perm is None:
            continue
        for attr in node.attribute:
            if attr.name == "perm":
                print("node %s: perm %s -> %s" % (node.name, list(attr.ints), list(perm)))
                # Edit the existing attribute's ints in place.  Deleting from
                # node.attribute while iterating over it would skip elements
                # and revisit the re-appended attribute.
                del attr.ints[:]
                attr.ints.extend(perm)

    return graph




def main():
    """Patch initializers and Transpose perms of a fixed model, then save it.

    The helper functions edit the loaded model's graph in place, so the
    loaded model itself is checked and saved.  Rebuilding it with
    ``onnx.helper.make_graph`` / ``make_model`` is avoided on purpose: that
    drops the original opset_import (replacing it with the installed onnx's
    default opset), ir_version, value_info and metadata.
    """
    #----------0. prepare work-----------
    input_onnx_path = "/root/Desktop/v6n.opt.onnx"
    output_onnx_path = "/root/Desktop/v6n.opt.modify1.onnx"

    #----------1. load onnx file-----------
    model = onnx.load(input_onnx_path)

    #----------2. edit the graph in place-----------
    replace_initializer_node(model.graph)
    modify_node_attribute(model.graph)

    #----------3. validate and save-----------
    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)
    onnx.save(model, output_onnx_path)

if __name__ == '__main__':
    main()
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值