onnx 模型切割掉conv后面的节点,设置输出层名称和最后节点名称一致,设置输出层shape和输出节点一致.

某些模型最后卷积层之后的算子不适合在推理引擎里面跑,切割掉conv后面的算子,在cpu上实现有比较好的性能.
包含:
1.获取onnx中间节点的shape的示例
2.增加onnx模型输出,设置名称,type, shape. 示例
3.编辑onnx模型示例


切割掉绿色部分示例
import onnx
import sys
import json
from onnx import shape_inference, TensorProto

if len(sys.argv) < 2:
    print('Usage: ' + sys.argv[0] + '<onnx_filename>')
    exit(-1)

onnx_file = sys.argv[1]

# 加载ONNX模型
model = onnx.load(onnx_file)

graph = model.graph

outputs = model.graph.output 
if(len(outputs)!=3):
    print("This isn't ScoreBoxKpt model!")
    quit()

output_list=["output0","output1","output2"]

for output in outputs:
    if output.name in score_box_kpt :
        print(f"output name: {output.name}")
    else:
        print("This isn't a fit model!")
        quit()

def getConvList(endName):
    stack=[]
    stack.append(endName)
    convList=[]
    while(len(stack)):
        name=stack.pop()
        for node in graph.node:
            if name in node.output :
                if node.op_type=="Conv":
                    if node.name not in convList :
                        convList.append(node.name)
                else: 
                    for input in node.input:
                        if input not in stack:
                            stack.insert(0, input)
    return convList

Conv0=getConvList(output_list[0])
Conv1=getConvList(output_list[1])
Conv2=getConvList(output_list[2])

def save2json(save_dict, name):
    if len(save_dict) == 0:
        print("this is nothing to save json")
        return None
    with open(name, 'w') as fp:
        #{'a': 'Runoob', 'b': 7}
        json.dump(save_dict, fp, sort_keys=False, indent=4, separators=(',', ': ')) #default=str

save_dict = {output_list[0]:scoreConv,
             output_list[1]:boxConv,
             output_list[2]:kptConv
            }

conv_list=Conv0+Conv1+Conv2

#获取onnx中间节点的shape.
output_dim_dic={}
inferred_onnx_model = shape_inference.infer_shapes(model)
inferred_graph = inferred_onnx_model.graph
inferred_value_info = inferred_graph.value_info
for node in graph.node:
    if node.name in conv_list:
        for value_info in inferred_value_info:
            if value_info.name==node.output[0]:
                output_dim_dic[node.name]=value_info.type.tensor_type;

#删除conv 后面的onnx节点
# Find target node index
for name in conv_list:
    target_node = None
    for node in graph.node:
        if node.name == name:
            target_node=node
            break
    output_names = []
    for output in target_node.output:
        output_names.append(output)

    set1=set(output_names)
    del_node = []

    have_new_del_node = False
    while True:
        have_new_del_node = False
        for node in graph.node:
            if node.name in del_node:
                continue
            set2=set(node.input)
            if set1.intersection(set2): 
                output_names+=node.output         
                set1=set(output_names)
                del_node.append(node.name)
                have_new_del_node = True
        if have_new_del_node == False:
            break

    for node in graph.node:
        if node.name in del_node:
            print(f"1remove node {node.name}")
            model.graph.node.remove(node)

have_new_del_node = False
while True:
    have_new_del_node = False
    for node1 in graph.node:
        if node1.name in conv_list :
            continue
        set1=set(node1.output)
        to_delete =True
        for node2 in graph.node:
            set2=set(node2.input)
            if set1.intersection(set2): 
                to_delete = False
                break
        if to_delete == True:
            print(f"2remove node {node1.name}")
            model.graph.node.remove(node1)
            have_new_del_node=True
    if have_new_del_node == False :
        break

save_output_name=[]
for node in graph.node:
    if node.name in conv_list:
     #增加输出层
        output_info = onnx.helper.ValueInfoProto()
        node.output[0]=node.name
        output_info.name = node.output[0]
        for dim_value in output_dim_dic[node.name].shape.dim:
            output_info.type.tensor_type.shape.dim.extend([dim_value])
        output_info.type.tensor_type.elem_type = TensorProto.FLOAT
        print(output_info)
        graph.output.extend([output_info])
        save_output_name.append(node.output[0])

outputs = model.graph.output 
# 打印输出节点名称
for output in outputs:
    if output.name  in save_output_name :
        continue
    model.graph.output.remove(output)
outputs = model.graph.output 
# 打印输出节点名称
for output in outputs:
    if output.name  in save_output_name :
        continue
    model.graph.output.remove(output)
# Save modified ONNX model
onnx.checker.check_model(model)
onnx.save(model, "backbone.onnx")
save2json(save_dict, 'conv_param.json'
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值