文章目录
1 ONNX模型API构造与代码检查
参考博客:https://zhuanlan.zhihu.com/p/516920606
1 构造描述张量信息的对象ValueInfoProto
import onnx
from onnx import helper
from onnx import TensorProto

# Describe the graph tensors (ValueInfoProto): three inputs and one output,
# all FLOAT with static shape [10, 10].
_SHAPE = [10, 10]
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, _SHAPE)
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, _SHAPE)
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, _SHAPE)
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, _SHAPE)
2 构造算子节点信息NodeProto
# Operator nodes (NodeProto): c = a * x, then output = c + b.
mul = helper.make_node(op_type='Mul', inputs=['a', 'x'], outputs=['c'])
add = helper.make_node(op_type='Add', inputs=['c', 'b'], outputs=['output'])
3 构造计算图GraphProto
# Assemble the GraphProto; nodes must already be in topological order.
graph = helper.make_graph(nodes=[mul, add], name='linear_func',
                          inputs=[a, x, b], outputs=[output])
4 封装计算图:用 helper.make_model 把计算图 GraphProto 封装进模型 ModelProto
# Wrap the GraphProto in a ModelProto (default IR version / opset imports).
model = helper.make_model(graph)
5 检查并保存模型
# Validate model structure; raises onnx.checker.ValidationError on failure.
onnx.checker.check_model(model)
# Dump the human-readable protobuf, then serialize the model to disk.
print(model)
onnx.save(model, 'linear_func.onnx')
2 ONNX Python API 构造模型完整代码
import onnx
from onnx import helper
from onnx import TensorProto


def _build_linear_func_model():
    """Build output = a * x + b as an ONNX ModelProto (all tensors FLOAT [10, 10])."""
    shape = [10, 10]
    # Graph inputs and output (ValueInfoProto), keyed by tensor name.
    tensors = {
        name: helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        for name in ('a', 'x', 'b', 'output')
    }
    # c = a * x ; output = c + b
    mul_node = helper.make_node('Mul', ['a', 'x'], ['c'])
    add_node = helper.make_node('Add', ['c', 'b'], ['output'])
    linear_graph = helper.make_graph(
        [mul_node, add_node],
        'linear_func',
        [tensors['a'], tensors['x'], tensors['b']],
        [tensors['output']],
    )
    return helper.make_model(linear_graph)


model = _build_linear_func_model()
# Validate, dump, and save the model.
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')
3 ONNX模型修改
参考博客:onnx模型图优化/模型修改
1 修改算子属性attrs
import onnx

onnx_model = onnx.load("models/conv_2d_backprop_static.onnx")
graph = onnx_model.graph
# print("graph_input:", graph.input)
# print("graph_output:", graph.output)
# print("nodes:", graph.node)

for node in graph.node:
    print("name:", node.name)
    print("op_type:", node.op_type)
    # print("input:", node.input)
    # print("output:", node.output)
    # print("attribute:", node.attribute)
    if node.name != "ConvTranspose__9":
        continue
    # First locate the "pads" attribute. The original code did
    # `del node.attribute[attr_id]` inside the enumerate loop; deleting
    # from a protobuf repeated field while iterating it skips elements
    # and can corrupt the traversal, so we only record the index here.
    pads_idx = None
    for attr_id, attr in enumerate(node.attribute):
        print("attr.name:", attr.name)
        print("attr.type:", attr.type)
        if attr.type == onnx.AttributeProto.AttributeType.INTS:
            print("attr.ints:", attr.ints)
        if attr.name == "pads":
            pads_idx = attr_id
    # Replace the attribute after the scan (you can also modify
    # attr.ints in place instead of replacing the whole attribute).
    if pads_idx is not None:
        pads_attr = onnx.helper.make_attribute("pads", [0, 1, 0, 2])
        del node.attribute[pads_idx]
        node.attribute.extend([pads_attr])
        print("node new attribute:", node.attribute)

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "modified.onnx")
2 创建模型
import onnx
import numpy as np
from onnx.helper import make_node, make_graph, make_tensor_value_info, make_model, make_opsetid

# Graph inputs/outputs; -1 marks a dynamic (unknown) batch dimension.
inputs = [
    make_tensor_value_info(name="input1", elem_type=onnx.TensorProto.FLOAT, shape=(-1, 16)),
    make_tensor_value_info(name="input2", elem_type=onnx.TensorProto.FLOAT, shape=(-1, 16)),
]
outputs = [make_tensor_value_info(name="output", elem_type=onnx.TensorProto.FLOAT, shape=(-1, 16))]

# Bake a random constant tensor into a Constant node.
const_shape = (1, 16)
const_values = np.random.uniform(-10, 10, size=const_shape).astype("float32")
const_node0 = make_node(
    op_type="Constant",
    inputs=[],
    outputs=["const1:0"],
    name="const1",
    value=onnx.helper.make_tensor(
        name='const1',
        data_type=onnx.TensorProto.FLOAT,
        dims=const_values.shape,
        vals=const_values.reshape(-1),
    ),
)
add_node0 = make_node(op_type="Add", inputs=["input1", "input2"], outputs=["add1:0"], name="add1")
add_node1 = make_node(op_type="Add", inputs=["add1:0", "const1:0"], outputs=["output"], name="add2")

# Nodes in a graph must be topologically sorted.
nodes = [const_node0, add_node0, add_node1]
graph = make_graph(nodes=nodes, name="add_test", inputs=inputs, outputs=outputs)
# graph.node.insert(idx, some_node) can likewise splice a node in before idx.
onnx_model = make_model(graph, opset_imports=[make_opsetid(domain="", version=11)])
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "add_model.onnx")
3 图中插入新节点
import onnx
import numpy as np
def create_pad(node, out_idx, pad_size):
    """Redirect `node`'s out_idx-th output through a reflect-mode Pad.

    The producer node is renamed to write "<name>_padded", while the Pad
    node emits the original name, so downstream consumers are untouched.
    Returns the two new nodes [pad-size Constant, Pad] in topological order.
    """
    original_name = node.output[out_idx]
    padded_name = original_name + "_padded"
    size_const_name = original_name + "_pad_size"
    node.output[out_idx] = padded_name

    # Pad sizes are supplied to the Pad op as an INT64 constant tensor.
    sizes = np.asarray(pad_size, dtype="int64")
    size_const_node = onnx.helper.make_node(
        op_type="Constant",
        inputs=[],
        outputs=[size_const_name],
        name=size_const_name,
        value=onnx.helper.make_tensor(
            name="const_value",
            data_type=onnx.TensorProto.INT64,
            dims=sizes.shape,
            vals=sizes.reshape(-1),
        ),
    )
    pad_node = onnx.helper.make_node(
        op_type="Pad",
        inputs=[padded_name, size_const_name],
        outputs=[original_name],
        name=original_name + "_pad",
    )
    pad_node.attribute.extend([onnx.helper.make_attribute("mode", 'reflect')])
    return [size_const_node, pad_node]
def get_node(onnx_model, node_name):
    """Find a graph node by name; return (node, index), or (None, 0) if absent."""
    for idx, candidate in enumerate(onnx_model.graph.node):
        if candidate.name == node_name:
            return candidate, idx
    return None, 0
model_path = "modified.onnx"
onnx_model = onnx.load(model_path)
conv2d_trans, node_id = get_node(onnx_model, "ConvTranspose__9")
pad_size = [0, 0, 0, 0, 0, 0, 1, 0]
new_nodes = create_pad(conv2d_trans, 0, pad_size)
# Splice the new nodes directly after the target node; advancing the insert
# position keeps their relative (topological) order intact.
for offset, new_node in enumerate(new_nodes):
    onnx_model.graph.node.insert(node_id + 1 + offset, new_node)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "test_add_pad.onnx")
4 删除节点
import onnx

model_path = "bert_model_int32.onnx"
out_model_path = "bert_model_int32_fp32.onnx"
onnx_model = onnx.load(model_path)
graph = onnx_model.graph

# Demo: remove a node with a single input and a single output by mapping
# its output name back to its input name.
in_rename_map = {}
for node_id, node in enumerate(graph.node):
    if node.name != "Cast_1185":
        continue
    in_rename_map = {node.output[0]: node.input[0]}
    del graph.node[node_id]
    break

# Rewire every consumer of the removed node's output to read its input instead.
for node in graph.node:
    for in_id, in_name in enumerate(node.input):
        if in_name in in_rename_map:
            node.input[in_id] = in_rename_map[in_name]

onnx.save(onnx_model, out_model_path)
5 提取子图
import onnx
def get_input_nodes(graph, node):
    """Return the producer nodes of `node`'s inputs.

    A producer is any graph node with at least one output whose name appears
    in `node.input`; each producer is returned at most once. This is a linear
    scan — for repeated queries, build a name -> producer map once instead.
    """
    in_names = node.input
    producers = []
    # NOTE: the original loop shadowed the `node` parameter (and carried an
    # unused enumerate index); use a distinct name to keep the scan readable.
    for candidate in graph.node:
        if any(out_name in in_names for out_name in candidate.output):
            producers.append(candidate)
    return producers
def get_node_cluster(graph, queue):
    """BFS backwards from the seed nodes in `queue`, returning every node
    the seeds (transitively) depend on, seeds included.

    `queue` is consumed (emptied) in the process.
    """
    visited_nodes = []
    while queue:
        current = queue.pop(0)
        visited_nodes.append(current)
        for producer in get_input_nodes(graph, current):
            if producer not in queue and producer not in visited_nodes:
                queue.append(producer)
    return visited_nodes
def get_output_nodes(graph, output_names):
    """Return the nodes producing any name in `output_names` (BFS seed set)."""
    seeds = []
    for candidate in graph.node:
        if any(out_name in output_names for out_name in candidate.output):
            seeds.append(candidate)
    return seeds
model_path = "model.onnx"
out_model_path = "model_extract.onnx"
output_names = [
    "conv1/7x7_s2/bn/sc_2",
]
onnx_model = onnx.load(model_path)
graph = onnx_model.graph

# Collect the subgraph feeding the requested outputs, keeping the original
# (topological) node order.
queue = get_output_nodes(graph, output_names)
visited_nodes = get_node_cluster(graph, queue)
sorted_nodes = [node for node in graph.node if node in visited_nodes]

# Every tensor name read by the subgraph.
all_in_names = set()
for node in visited_nodes:
    all_in_names.update(node.input)

# True graph inputs: referenced by the subgraph and NOT backed by an
# initializer. (The original computed initializer_names but never applied
# the filter its comment described.)
initializer_names = {_init.name for _init in graph.initializer}
inputs = [
    _input for _input in graph.input
    if _input.name in all_in_names and _input.name not in initializer_names
]

# Output ValueInfoProtos; shape left fully dynamic.
outputs = [
    onnx.helper.make_tensor_value_info(
        name=_name, elem_type=onnx.TensorProto.FLOAT, shape=[-1, -1, -1, -1])
    for _name in output_names
]

graph_n = onnx.helper.make_graph(nodes=sorted_nodes, name="extracted_graph",
                                 inputs=inputs, outputs=outputs)
# Carry over only the initializers the subgraph actually references.
for initializer in graph.initializer:
    if initializer.name in all_in_names:
        graph_n.initializer.append(initializer)

# Do NOT rebuild with onnx.helper.make_model(graph_n): that would reset the
# opset imports, which must match the original model's. Replace the graph
# in place instead.
onnx_model.graph.CopyFrom(graph_n)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, out_model_path)
6 修改输入输出名称
import onnx

onnx_model = onnx.load('model.onnx')
endpoint_names = ['image_tensor:0', 'output:0']

# Strip the ":0" tensor-index suffix from matching node edge names.
for node in onnx_model.graph.node:
    for j, in_name in enumerate(node.input):
        if in_name in endpoint_names:
            print('-'*60)
            print(node.name)
            print(node.input)
            print(node.output)
            node.input[j] = in_name.split(':')[0]
    for j, out_name in enumerate(node.output):
        if out_name in endpoint_names:
            print('-'*60)
            print(node.name)
            print(node.input)
            print(node.output)
            node.output[j] = out_name.split(':')[0]

# Rename the graph-level inputs and outputs to match.
for graph_input in onnx_model.graph.input:
    if graph_input.name in endpoint_names:
        print('-'*60)
        print(graph_input)
        graph_input.name = graph_input.name.split(':')[0]
for graph_output in onnx_model.graph.output:
    if graph_output.name in endpoint_names:
        print('-'*60)
        print(graph_output)
        graph_output.name = graph_output.name.split(':')[0]

onnx.save(onnx_model, 'model_mod.onnx')
7 修改输入数据类型(将 INT64 输入改为 INT32)
import onnx

model_path = "bert_model.onnx"
out_model_path = "bert_model_int32.onnx"
onnx_model = onnx.load(model_path)
graph = onnx_model.graph
print("graph_input:", graph.input)
print("graph_output:", graph.output)

# Downcast every INT64 graph input declaration to INT32. Note that only the
# declared input type is changed; the graph body must already accept INT32.
for graph_input in graph.input:
    tensor_type = graph_input.type.tensor_type
    if tensor_type.elem_type == onnx.TensorProto.DataType.INT64:
        tensor_type.elem_type = onnx.TensorProto.DataType.INT32

print("updated graph_input:", onnx_model.graph.input)
print("updated graph_output:", onnx_model.graph.output)
onnx.save(onnx_model, out_model_path)