第五章:ONNX 模型的修改与调试 — mmdeploy 1.2.0 文档
ONNX神经网络定义标准本身,探究ONNX模型构造、读取、子模型读取、调试。
1、首先学习ONNX的底层表示方式。
2、ONNX API构造和读取模型。
3、利用ONNX提供子模型提取功能,学习如何调试ONNX模型。
ONNX的底层实现
ONNX的存储格式
ONNX在底层实验protobuf定义的。protobuf全程是protocol buffer,是google提出的一套表示和序列化数据的机制。使用protobuf时,用户需要先写一份数据定义文件,在根据这份定义文件把数据存储在一份二进制文件。可以说,数据定义文件就是数据类,二进制文件就是数据类型的实现。
message Person {#类
required string name = 1;#必须包含
required int32 id = 2;#必须包含
optional string email = 3;#可选
}
#类的实例用protobuf存储成二进制文件,反之用户也可以用二进制文件和对应的数据定义文件,读取出一个person类的实例。
ONNX的protobuf数据定义文件在其开源库中,这些文件定义了神经网络中的模型、节点、张量的数据类型规范;而数据定义文件对应的二进制文件和对应的数据定义文件就是.ONNX文件,每一个.onnx文件按照数据定义规范,存储了一个神经网络的所有相关数据。
ONNX的开源文件中定义了神经网络的模型、节点、张量类。
.onnx文件就是存储了一个神经网络的相关数据。
ONNX 的结构定义
ONNX怎么在protobuf定义文件里描述一个神经网络。
神经网络的本质是一个计算图,计算图的节点是算子,边是参与运算的张量,查看ONNX模型知道,ONNX记录了所有算子节点的属性信息,并把参与运算的张量信息存储在算子节点的输入输出信息中。
ONNX模型结构类图
ONNX模型用ModelProto类表示:版本、创建者、graph计算图结构。
graphProto类包含:输入张量信息、输出张量信息、节点信息。
valueinfoproto类包含:张量名、基本数据类型、形状。
NodeProto类包含了:算子名称、算子输入张量名、算子输出张量名。
#output = a*x+b
ir_version: 8
graph {
node {
input: "a"
input: "x"
output: "c"
op_type: "Mul"#图中包含一个乘法节点
}
node {
input: "c"
input: "b"
output: "output"
op_type: "Add"#图中包含一个加法节点
}
name: "linear_func"
input {#输入张量
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
output {#输出张量
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim { dim_value: 10}
dim { dim_value: 10}
}
}
}
}
}
opset_import {version: 15}
读写 ONNX 模型
构造 ONNX 模型
ModelProto
GraphProto
NodeProto
ValueInfoProto
图一般是一个节点集和一个边集表示的,而ONNX巧妙的把边的信息保存在节点信息里面,省去了保存边集的步骤。在ONNX中如果节点的输入名和之前某节点的输出名相同,就默认这两个节点是相连的。
这种边的隐式定义规则,ONNX对节点的输入有一定的要求:一个节点的输入,要么是真个模型的输入,要么是之前某个节点的输出。
onnx.checker.check_model判断一个ONNX模型是否满足ONNX标准。
只需要关注mul和add节点以及它们之间的边c,如果按照[Mul,Add]顺序给出,则遍历Add时输入c可以在Mul的输出中找到。但如果节点以[Add,Mul]顺序给出,那Add就找不到输入边,计算图也无法成功构建出来,[Mul,Add]就是符合有向图的拓扑序,而[Add,Mul]则不满足。
#完全用ONNX的python API构造一个描述线性函数的output = a*x +b的ONNX模型
import onnx
from onnx import helper
from onnx import TensorProto
#helper.make_tensor_value_info构造一个描述张量信息的ValueInfoProto
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])#张量名、张量基本数据类型、张量形状
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])
# 构造算子节点信息NodeProto,可以在helper.make_node中传入算子类型、输入张量名称、输出张量名称
mul = helper.make_node('Mul', ['a', 'x'], ['c'])#先描述c=a*x的乘法节点;定义了输出c
add = helper.make_node('Add', ['c', 'b'], ['output'])#构造output=c+b:定义了输入c,则Mul节点和Add节点是相连的。
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])#构造计算图GraphProto。需要传入节点名、图名称、
# 输入张量信息,输出张量信息4个参数,这几个参数在上面的NodeProto和valueinfoproto对象按照顺序传入即可。
#make_graph节点参数:计算图的界定必须以拓扑序给出。拓扑序数与有向图相关的数学概念,如果按拓扑序遍历所有节点的话,能保证每个节点的输入都能在
# 之前节点输出中找到,计算图输入张量也是之前的输出。
model = helper.make_model(graph)#把计算图GraphProto封装进模型ModelProto里,
onnx.checker.check_model(model)#检查模型是否正确
print(model)
onnx.save(model, 'linear_func.onnx')#保存模型
import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession('linear_func.onnx')
a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(10, 10).astype(np.float32)
x = np.random.rand(10, 10).astype(np.float32)
output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0]
assert np.allclose(output, a * x + b)#判断模型是否正确
读取并修改 ONNX 模型
import onnx
#读取onnx模型,与onnx.save对应,写入和读取的都是ModelProto对象
model = onnx.load('linear_func.onnx')
print(model)
#读取图GraphProto、节点NodeProto、张量信息ValueInfoProto
graph = model.graph
node = graph.node
input = graph.input
output = graph.output
print(node)#就是一个列表,列表中对象有属性input(列表)、output、op_type
print(input)
print(output)
#获取第一个节点mul的属性
node_0 = node[0]
node_0_inputs = node_0.input
node_0_outputs = node_0.output
input_0 = node_0_inputs[0]
input_1 = node_0_inputs[1]
output = node_0_outputs[0]
op_type = node_0.op_type
print(input_0)
print(input_1)
print(output)
print(op_type)
import onnx
model = onnx.load('linear_func.onnx')#读取模型,
node = model.graph.node
node[1].op_type = 'Sub'#修改第二个节点类型,将加法变为减法
onnx.checker.check_model(model)
onnx.save(model, 'linear_func_2.onnx')#a*x-b
调试 ONNX 模型
子模型提取
extract从一个给定的ONNX模型中,拿出一个子模型,这个子模型的节点集、边集都是原模型中对应集合的子集
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3))
self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3))
self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3))
self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3),
torch.nn.Conv2d(3, 3, 3))
def forward(self, x):
x = self.convs1(x)
x1 = self.convs2(x)
x2 = self.convs3(x)
x = x1 + x2
x = self.convs4(x)
return x
model = Model()
input = torch.randn(1, 3, 20, 20)
# 边用同名张量表述,边的序号实际上是上一个节点的输出张量序号和后一个节点的输入张量序号,张量序号是pytorch自动生成的
torch.onnx.export(model, input, 'whole_model.onnx')
import onnx
#将计算图中边22到边28子图提取出来,组成一个子模型。源模型路径、输出模型路径、子模型输入边(输入张量)、子模型输出边(输出张量)
onnx.utils.extract_model('whole_model.onnx', 'partial_model.onnx', ['22'], ['28'])
#子模型提取就是把输入边到输出边之间的全部节点都提取出来
特别注意,我这里横跨了torch1.13和torch2.0版本,实际上torch1上面没问题,可能在torch2上面会有问题需要注意的
多一个输出
onnx.utils.extract_model('whole_model.onnx', 'submodel_1.onnx', ['22'], ['27', '31'])
多一个输入
onnx.utils.extract_model('whole_model.onnx', 'submodel_2.onnx', ['22', 'input.1'], ['28'])#多输入了一个input.1节点
#少一个输入(报错),输入节点数不足以计算
onnx.utils.extract_model('whole_model.onnx', 'submodel_3.onnx', ['24'], ['28'])#少一个输入;报错
子模型提取的实现原理:
1、新建一个模型
2、填入给定的输入和输出。
3、把图的所有有向边反向,送输出边开始遍历到输入边。
4、将遍历得到的节点作为子模型节点。
输出 ONNX 中间节点的值
主要是保证深度学习框架模型和ONNX模型精度对齐,只要输出中间节点的值就能定位到精度出现偏差的算子。
可以在保证原有输入输出不变的情况下,添加一些输出,提取出能够输出中间节点的子模型。有点类似于深度监督。
onnx.utils.extract_model('whole_model.onnx', 'more_output_model.onnx', ['input.1'], ['31', '23', '25', '27'])
为了方便调试,将原模型拆分成多个互不相交的子模型。这样调试的时候就可以只对原模型的部分子模型进行调试。将原本复杂的模型拆分成多个简单的子模型。这样就可以在调试的时候,先调试顶层子模型,确认顶层子模型无误后,把它的输出作为后面子模型的输入。
查询节点名称。
pytorch导出ONNX模型由以下几个问题:
1、一旦Pytorch模型改变,ONNX模型边序号也会改变,则每次提取同样子模块式都需要重新去ONNX模型里查询序号,实际上不会使用这么麻烦的调试方法。
2、加入ONNX边序号不发生改变,也不一定能把pytorch代码和ONNX节点对应起来,模型复杂时,ONNX节点含义难以对应。
总结:
1、ONNX使用的是protobuf定义规范和序列化模型。
2、ONNX模型主要是由modelproto、graphproto、nodeproto、valueinfoproto数据类对象组成。
3、onnx.helper.maker_xxx可以构造ONNX模型的数据对象。
4、onnx.save可以保存模型,onnx.load可以读取模型,onnx.checker.check_model可以检查模型是否符合规范。
5、onnx.utils.extract_model可以从元模型中提取部分节点、输入、输出边构建一个新的子模型。
6、利用子模型提取功能,可以输出原ONNX模型中姐姐狗,实现对ONNX模型的调试。