1 onnx简介
…
2 模型解析
因为ONNX文件保存的是模型的计算图和模型的预训练参数,所以这里先解析ONNX的计算图,对整个计算流程有一个整体认识。
计算图就是多个 op 的组合,每个 op 都有输入,输出,然后将所有的 op 结合起来,形成一个 graph
2.1 onnx protobuf定义解析
整个定义的层次是 ModelProto -> GraphProto -> NodeProto,
主要就是这三个部分
最外层是ModelProto,记录一些模型信息:IR版本(ir_version)、opset版本(opset_import)、生成模型的框架/工具(producer_name,如pytorch/tensorflow)等,以及计算图GraphProto。
// Top-level container of an ONNX model: IR and operator-set versions,
// producer/model metadata, and the computation graph itself (field `graph`).
message ModelProto {
  // The version of the IR this model targets. See the Version enum in the
  // full onnx.proto (not reproduced in this excerpt).
  // This field MUST be present.
  int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  //
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  string domain = 4;

  // The version of the graph encoded.
  // NOTE(review): the upstream comment says "See Version enum below"; that
  // enum is not part of this excerpt — confirm against the full onnx.proto.
  int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  // This is the part the notes below parse (nodes + initializers).
  GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;
};
GraphProto才是核心,里面主要包含:
1. TensorProto initializer:保存常量tensor(const tensor)和预训练的参数(权重)。
2. NodeProto node:保存每个op的类型(op_type)、属性(attribute)以及输入、输出tensor的名字。
3. ValueInfoProto input / output:整个计算图的输入、输出tensor。
// The computation graph of a model: the nodes (ops), the graph-level
// inputs/outputs, and the constant tensors (initializers) that nodes consume.
// Nodes are linked purely by tensor *names*: a node input refers to a node
// output, an initializer, or a graph input with the same name.
message GraphProto {
  // The nodes in the graph, sorted topologically.
  repeated NodeProto node = 1;

  // The name of the graph.
  string name = 2; // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each TensorProto entry must have a distinct name (within the list) that
  // also appears in the input list.
  // NOTE(review): the "also appears in the input list" rule is from an older
  // onnx.proto; later IR versions relax it — confirm against the IR version
  // of the model being inspected.
  repeated TensorProto initializer = 5;

  // A human-readable documentation for this graph. Markdown is allowed.
  string doc_string = 10;

  // The inputs and outputs of the graph.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // DO NOT USE the following fields, they were deprecated from earlier versions.
  // repeated string input = 3;
  // repeated string output = 4;
  // optional int64 ir_version = 6;
  // optional int64 producer_version = 7;
  // optional string producer_tag = 8;
  // optional string domain = 9;
  //
  // Reserve the retired numbers (and the names not still in use) so they can
  // never be reused with a different meaning. "input"/"output" cannot be
  // reserved as names because fields 11 and 12 still use them.
  reserved 3, 4, 6 to 9;
  reserved "ir_version", "producer_version", "producer_tag", "domain";
}
NodeProto
// One operator invocation in a GraphProto. The node is connected to the rest
// of the graph only through the tensor names in `input` and `output`.
message NodeProto {
  // Names of the tensors this node consumes; each should match a node output,
  // a graph initializer, or a graph input elsewhere in the graph.
  // NOTE(review): an empty string presumably marks an omitted optional
  // input — confirm against the ONNX IR spec.
  repeated string input = 1; // namespace Value

  // Names of the tensors this node produces.
  repeated string output = 2; // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  string name = 3; // namespace Node

  // The symbolic identifier of the Operator to execute.
  string op_type = 4; // namespace Operator

  // The domain of the OperatorSet that specifies the operator named by op_type.
  string domain = 7; // namespace Domain

  // Additional named attributes.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  string doc_string = 6;
}
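所以计算图中每个node的输入名记录在node.input,输出名记录在node.output。一个node的输入要么是其他node的输出,要么是graph.initializer中的常量tensor(预训练参数),要么是graph.input中的图输入。
总结起来就是:(inputs - outputs) ⊆ (graph.initializer ∪ graph.input);而 outputs - inputs 是没有被任何node使用的输出,其中通常包含graph.output(若还有其他名字,说明存在未被使用的node输出)。注意:node.input中的空字符串表示省略了可选输入,统计时应忽略。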
python解析代码
import onnx
def onnx_parser(onnx_path):
    """Load an ONNX model and check where every node input comes from.

    In a top-level ONNX graph, each tensor name a node consumes
    (``node.input``) is produced by one of:

    * another node (the name appears in some ``node.output``),
    * ``graph.initializer`` (constant tensors / pretrained weights), or
    * ``graph.input`` (tensors fed in at run time).

    So ``consumed - produced`` should be covered by initializers plus graph
    inputs, and ``produced - consumed`` holds outputs that no node uses
    (normally graph outputs; anything else is a dangling node output).

    Prints tensor-name counts, one "<name> not in onnx initial tensor" line
    for every externally supplied name that is not an initializer, and a
    WARNING line for names that are neither initializers nor graph inputs.

    Args:
        onnx_path: Anything ``onnx.load`` accepts (file path or file-like
            object), or an already-loaded model exposing a ``graph``
            attribute, which is used as-is without reloading.

    Returns:
        dict mapping each key below to a sorted list of tensor names:
            "initializers"       -- names in graph.initializer
            "graph_inputs"       -- names in graph.input
            "graph_outputs"      -- names in graph.output
            "external_inputs"    -- consumed by a node but produced by none
            "unresolved"         -- external, yet neither an initializer
                                    nor a graph input
            "unconsumed_outputs" -- produced by a node but consumed by none
            "dangling_outputs"   -- unconsumed and not a graph output

    NOTE(review): nodes inside control-flow subgraphs (If/Loop/Scan graph
    attributes) are not visited, so names used only there are not counted.
    """
    # An in-memory model can be passed directly; otherwise defer to onnx.load.
    model = onnx_path if hasattr(onnx_path, "graph") else onnx.load(onnx_path)
    graph = model.graph

    initializer_names = {init.name for init in graph.initializer}
    graph_input_names = {value.name for value in graph.input}
    graph_output_names = {value.name for value in graph.output}

    # ONNX uses an empty name for an omitted optional input/output; such
    # entries refer to no tensor and must not be reported as missing.
    inputs = [name for node in graph.node for name in node.input if name]
    outputs = [name for node in graph.node for name in node.output if name]
    consumed = set(inputs)
    produced = set(outputs)

    print("len inputs: ", len(inputs))
    print("len outputs: ", len(outputs))
    print("len unique inputs: ", len(consumed))
    print("len unique outputs: ", len(produced))

    # Names consumed but produced by no node must come from outside the node
    # list. Sorted so the report order is stable across runs (set iteration
    # order of strings varies with hash randomization).
    external = sorted(consumed - produced)
    unresolved = []
    for name in external:
        if name in initializer_names:
            continue
        print("{} not in onnx initial tensor".format(name))
        if name not in graph_input_names:
            # Neither a constant nor a runtime input: nothing defines it.
            print("WARNING: {} is neither an initializer nor a graph input"
                  .format(name))
            unresolved.append(name)

    # Produced but never consumed: expected for graph outputs only.
    unconsumed = sorted(produced - consumed)
    dangling = [name for name in unconsumed if name not in graph_output_names]

    return {
        "initializers": sorted(initializer_names),
        "graph_inputs": sorted(graph_input_names),
        "graph_outputs": sorted(graph_output_names),
        "external_inputs": external,
        "unresolved": unresolved,
        "unconsumed_outputs": unconsumed,
        "dangling_outputs": dangling,
    }
if __name__ == '__main__':
    import sys

    # Take the model path from the command line when given; with no argument
    # fall back to the original hard-coded file so existing usage is unchanged.
    onnx_path = sys.argv[1] if len(sys.argv) > 1 else 'mqbench_qmodel_for_tengine.onnx'
    ret = onnx_parser(onnx_path)