1. Introduction
Why talk about ONNX, and what exactly is it? If you regularly deploy neural network applications you are probably already familiar with it: in many tasks we first convert a Pytorch or TensorFlow model into an ONNX model (ONNX usually serves as an intermediate stage of deployment), and then convert that ONNX model into whatever format the target deployment framework requires.
A few typical routes:
- Pytorch -> ONNX -> TensorRT
- Pytorch -> ONNX -> TVM
- TensorFlow -> ONNX -> ncnn
And so on. ONNX essentially acts as a translator between frameworks, which is exactly why it is called the Open Neural Network Exchange.
2. ONNX
If the introduction alone left you a bit confused on your first encounter with ONNX, let's be more concrete. Suppose we train a model with Pytorch and save it as a .pt file, say model.pt. This should be familiar: it is a binary weight file, and by reading it we can preload the trained weights into the model.
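For concreteness, here is a minimal sketch of that save/load round trip (the tiny nn.Linear model is just a placeholder, not the author's actual network):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                     # placeholder standing in for a trained network
torch.save(model.state_dict(), "model.pt")   # save the binary weight file

state = torch.load("model.pt")               # read the weights back...
model.load_state_dict(state)                 # ...and preload them into the model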
So what about ONNX? With Pytorch we can convert model.pt into a weight file in the model.onnx format, where .onnx is simply the file extension and model.onnx denotes an ONNX-format weight file. This file contains not only the weight values, but also the structure of the network (how data flows through it), the input and output information of every layer, and some other auxiliary metadata.
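As a small sketch of what "also contains the network structure" means in practice (assuming a model.onnx file exists and the onnx Python package, introduced later in this post, is installed):

import onnx

m = onnx.load("model.onnx")                            # parse the ONNX protobuf file
print([node.op_type for node in m.graph.node])         # the operations, in topological order
print([init.name for init in m.graph.initializer])     # the stored weight tensors
print(m.graph.input, m.graph.output)                   # input/output descriptions of the graph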
Let's simply use the netron tool to visualize (i.e., read) an ONNX file. As shown in the figure, various pieces of information stored in the ONNX file are displayed, for example the file format (ONNX v3) and the exporter that produced the file (pytorch 0.4). All of this information is stored inside the ONNX-format file itself.
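If you want to reproduce that view, netron can be installed with pip; to my knowledge it can then be launched from Python roughly like this (a sketch, check the netron documentation for the exact API):

import netron

netron.start("model.onnx")   # opens a local web page visualizing the model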
3. What is Protobuf?
Since ONNX is a file format, we need a set of rules to read from and write to it. ONNX uses protobuf, a serialized data structure protocol, to store the neural network's weights and structure.
What is Protobuf? If you have used caffe or caffe2 you are probably already familiar with it, because caffe also stores its models using the Protobuf protocol.
Here is a brief introduction. Protobuf (Protocol Buffers) is a platform-neutral, language-neutral, extensible, lightweight and efficient protocol for serializing structured data, usable both for network communication and for data storage. With protobuf we can define our own data structure protocol and then read or write it from a variety of languages; the one most commonly used is C++.
There is an article that introduces this quite well: https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/index.html, so this post will not go into the details.
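As a minimal illustration of that serialize/read-back workflow, we can reuse ONNX's own protobuf-generated TensorProto class from Python (this uses the onnx package introduced later, purely as a convenient example of a protobuf message):

from onnx import TensorProto, helper

# build a protobuf message describing a 2x2 float tensor
t = helper.make_tensor("w", TensorProto.FLOAT, dims=[2, 2], vals=[1.0, 2.0, 3.0, 4.0])

data = t.SerializeToString()        # protobuf: serialize the message to bytes
restored = TensorProto()
restored.ParseFromString(data)      # protobuf: parse the bytes back into a message
print(restored.name, list(restored.dims))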
4. The ONNX data format
The core of ONNX is the onnx.proto file, which defines the rules of the ONNX data protocol along with some other information. The onnx.proto file shown below helps us understand exactly what kind of information ONNX contains. To keep the description concise, some less important parts have been omitted and only the most essential parts are kept:
// Copyright (c) Facebook Inc. and Microsoft Corporation.
// Licensed under the MIT license.

syntax = "proto2";

// ... part omitted ...

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
// A Node is one operation in the network, e.g. conv, reshape, relu, etc.
message NodeProto {
  repeated string input = 1;    // namespace Value
  repeated string output = 2;   // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  optional string name = 3;     // namespace Node

  // The symbolic identifier of the Operator to execute.
  optional string op_type = 4;  // namespace Operator
  // The domain of the OperatorSet that specifies the operator named by op_type.
  optional string domain = 7;   // namespace Domain

  // Additional named attributes.
  // attribute holds per-node information; for a conv node, e.g. kernel size, stride, etc.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  optional string doc_string = 6;
}

// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
// Model is the outermost unit; it contains the Graph plus some version information
message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  optional int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  //
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  optional string domain = 4;

  // The version of the graph encoded. See Version enum below.
  optional int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  optional string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  // The important part: graph is the directed acyclic graph that holds the network information
  optional GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;
};

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
  optional string key = 1;
  optional string value = 2;
};

// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
// Graph is the most important part: it contains the model structure, the model weights,
// and everything else we need
message GraphProto {
  // The nodes in the graph, sorted topologically.
  // The topologically sorted nodes; each node is one operation in the model, e.g. `conv`
  repeated NodeProto node = 1;

  // The name of the graph.
  optional string name = 2;   // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each TensorProto entry must have a distinct name (within the list) that
  // also appears in the input list.
  // initializer stores all of the model's parameters, i.e. what we usually call the weights
  repeated TensorProto initializer = 5;

  // A human-readable documentation for this graph. Markdown is allowed.
  optional string doc_string = 10;

  // The inputs and outputs of the graph.
  // input lists all inputs of the model, including the initial input image and the inputs of every node
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // DO NOT USE the following fields, they were deprecated from earlier versions.
  // repeated string input = 3;
  // repeated string output = 4;
  // optional int64 ir_version = 6;
  // optional int64 producer_version = 7;
  // optional string producer_tag = 8;
  // optional string domain = 9;
}

// Tensors
//
// A serialized tensor value.
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;    // float
    UINT8 = 2;    // uint8_t
    INT8 = 3;     // int8_t
    UINT16 = 4;   // uint16_t
    INT16 = 5;    // int16_t
    INT32 = 6;    // int32_t
    INT64 = 7;    // int64_t
    STRING = 8;   // string
    BOOL = 9;     // bool

    // IEEE754 half-precision floating-point format (16 bits wide).
    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;   // complex with float32 real and imaginary components
    COMPLEX128 = 15;  // complex with float64 real and imaginary components

    // Non-IEEE floating-point format based on IEEE754 single-precision
    // floating-point number truncated to 16 bits.
    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
    BFLOAT16 = 16;

    // Future extensions go here.
  }

  // The shape of the tensor.
  repeated int64 dims = 1;

  // The data type of the tensor.
  // This field MUST have a valid TensorProto.DataType value
  optional int32 data_type = 2;

  // ... part omitted ...
}

// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
  message Dimension {
    oneof value {
      int64 dim_value = 1;
      string dim_param = 2;   // namespace Shape
    };
    // Standard denotation can optionally be used to denote tensor
    // dimensions with standard semantic descriptions to ensure
    // that operations are applied to the correct axis of a tensor.
    // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
    // for pre-defined dimension denotations.
    optional string denotation = 3;
  };
  repeated Dimension dim = 1;   // the dimensions of this tensor
}

// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
  // The domain of the operator set being identified.
  // The empty string ("") or absence of this field implies the operator
  // set that is defined as part of the ONNX specification.
  // This field MUST be present in this version of the IR when referring to any other operator set.
  optional string domain = 1;

  // The version of the operator set being identified.
  // This field MUST be present in this version of the IR.
  optional int64 version = 2;
}
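To make the relationships between NodeProto, GraphProto and ModelProto concrete, here is a minimal, illustrative sketch that builds a one-node model with the onnx helper functions (names such as "tiny-graph" and "onnx-example" are made up for the example):

import onnx
from onnx import helper, TensorProto

# NodeProto: one operation; inputs and outputs are referenced by name
node = helper.make_node("Relu", inputs=["x"], outputs=["y"])

# GraphProto: the DAG of nodes plus its inputs/outputs (and initializers for weights)
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 224, 224])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 224, 224])
graph = helper.make_graph([node], "tiny-graph", inputs=[x], outputs=[y])

# ModelProto: wraps the graph together with ir_version, producer and opset metadata
model = helper.make_model(graph, producer_name="onnx-example")
onnx.checker.check_model(model)
print(model.ir_version, model.graph.node[0].op_type)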
5. ONNX versions
The version information of an ONNX model exported by Pytorch-1.0 is:
ONNX IR version:  0.0.3
Opset version:    9
Producer name:    pytorch
Producer version: 0.4
ONNX itself keeps evolving, and the set of supported operators keeps growing. An operator is something like a conv, pool or relu layer of a neural network; inside ONNX each of these becomes a node.
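To check which versions your own environment and a given model are using, a small sketch (model.onnx is a placeholder path):

import onnx

print(onnx.__version__)                    # version of the installed onnx package
print(onnx.defs.onnx_opset_version())      # highest opset version this package supports

m = onnx.load("model.onnx")                # placeholder path
print(m.ir_version)                        # IR version the model targets
print(m.opset_import)                      # operator set(s) the model was exported against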
6. An example
Installing onnx is straightforward: pip install onnx is all we need, and this installs protobuf along with it.
With the onnx package we just installed, we can read the .onnx file at this URL: https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/super_resolution_0.2.onnx
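For example, the file can be downloaded and parsed roughly like this (a sketch; the local filename is arbitrary):

import urllib.request
import onnx

url = ("https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/"
       "raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/super_resolution_0.2.onnx")
urllib.request.urlretrieve(url, "super_resolution_0.2.onnx")

model = onnx.load("super_resolution_0.2.onnx")   # parse the protobuf file
print(model)                                     # prints the ModelProto text dump shown below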
Inspecting the loaded object (for example in the debugger, or simply by printing it) shows the serialized data in exactly the format described by the onnx.proto file above. The dump below describes the complete structure of a neural network, including the operation performed by every layer (node):
ir_version: 1
producer_name: "pytorch"
producer_version: "0.2"
domain: "com.facebook"
graph {
  node {
    input: "1"
    input: "2"
    output: "11"
    op_type: "Conv"
    attribute { name: "kernel_shape" ints: 5 ints: 5 }
    attribute { name: "strides" ints: 1 ints: 1 }
    attribute { name: "pads" ints: 2 ints: 2 ints: 2 ints: 2 }
    attribute { name: "dilations" ints: 1 ints: 1 }
    attribute { name: "group" i: 1 }
  }
  node {
    input: "11"
    input: "3"
    output: "12"
    op_type: "Add"
    attribute { name: "broadcast" i: 1 }
    attribute { name: "axis" i: 1 }
  }
  node {
    input: "12"
    output: "13"
    op_type: "Relu"
  }
  node {
    input: "13"
    input: "4"
    output: "15"
    op_type: "Conv"
    attribute { name: "kernel_shape" ints: 3 ints: 3 }
    attribute { name: "strides" ints: 1 ints: 1 }
    attribute { name: "pads" ints: 1 ints: 1 ints: 1 ints: 1 }
    attribute { name: "dilations" ints: 1 ints: 1 }
    attribute { name: "group" i: 1 }
  }
  node {
    input: "15"
    input: "5"
    output: "16"
    op_type: "Add"
    attribute { name: "broadcast" i: 1 }
    attribute { name: "axis" i: 1 }
  }
  node {
    input: "16"
    output: "17"
    op_type: "Relu"
  }
  node {
    input: "17"
    input: "6"
    output: "19"
    op_type: "Conv"
    attribute { name: "kernel_shape" ints: 3 ints: 3 }
    attribute { name: "strides" ints: 1 ints: 1 }
    attribute { name: "pads" ints: 1 ints: 1 ints: 1 ints: 1 }
    attribute { name: "dilations" ints: 1 ints: 1 }
    attribute { name: "group" i: 1 }
  }
  node {
    input: "19"
    input: "7"
    output: "20"
    op_type: "Add"
    attribute { name: "broadcast" i: 1 }
    attribute { name: "axis" i: 1 }
  }
  node {
    input: "20"
    output: "21"
    op_type: "Relu"
  }
  node {
    input: "21"
    input: "8"
    output: "23"
    op_type: "Conv"
    attribute { name: "kernel_shape" ints: 3 ints: 3 }
    attribute { name: "strides" ints: 1 ints: 1 }
    attribute { name: "pads" ints: 1 ints: 1 ints: 1 ints: 1 }
    attribute { name: "dilations" ints: 1 ints: 1 }
    attribute { name: "group" i: 1 }
  }
  node {
    input: "23"
    input: "9"
    output: "24"
    op_type: "Add"
    attribute { name: "broadcast" i: 1 }
    attribute { name: "axis" i: 1 }
  }
  node {
    input: "24"
    output: "25"
    op_type: "Reshape"
    attribute { name: "shape" ints: 1 ints: 1 ints: 3 ints: 3 ints: 224 ints: 224 }
  }
  node {
    input: "25"
    output: "26"
    op_type: "Transpose"
    attribute { name: "perm" ints: 0 ints: 1 ints: 4 ints: 2 ints: 5 ints: 3 }
  }
  node {
    input: "26"
    output: "27"
    op_type: "Reshape"
    attribute { name: "shape" ints: 1 ints: 1 ints: 672 ints: 672 }
  }
  name: "torch-jit-export"
  initializer {
    dims: 64
    dims: 1
    dims: 5
    dims: 5
    data_type: FLOAT
    name: "2"
    raw_data: "\034
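The raw protobuf dump is quite verbose; for a more compact human-readable summary of just the graph, the onnx helper can be used (a small aside, assuming model is the ModelProto loaded above):

print(onnx.helper.printable_graph(model.graph))   # one line per node, with inputs and outputs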
7. Exporting from Pytorch
In the snippet below, model is the Pytorch model, example is a dummy input, and export_params=True means the weights are exported together with the graph structure.
import torch

# test_model() and 'test.pth' are the author's own model definition and checkpoint
model = test_model()
state = torch.load('test.pth')
model.load_state_dict(state['model'], strict=True)

example = torch.rand(1, 3, 128, 128)   # dummy input defining the input shape
torch_out = torch.onnx.export(model,
                              example,
                              "test.onnx",
                              verbose=True,
                              export_params=True)
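As a quick sanity check (a hedged suggestion, not part of the original snippet), the exported file can be reloaded and validated with the onnx checker:

import onnx

onnx_model = onnx.load("test.onnx")
onnx.checker.check_model(onnx_model)   # raises an exception if the exported graph is malformed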
8. A known bug when exporting ONNX models from Pytorch
When view is used to simulate a flatten operation, the operator in the exported onnx is not the one you would expect (see the list of supported operators: https://pytorch.org/docs/stable/onnx.html#supported-operators):
# output = input.view(input.size(0), -1)       # the onnx flatten exported from this Pytorch op is wrong
output = input.view([int(input.size(0)), -1])  # a temporary workaround
output = input.flatten(1)                      # the correct way
Related issue: https://github.com/pytorch/pytorch/issues/13963
Fix: https://github.com/pytorch/pytorch/pull/16240
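A quick, hedged way to see which operator actually ends up in the exported graph is to export a tiny module and inspect the node types (FlattenNet and flatten_test.onnx are made up for the example):

import torch
import torch.nn as nn
import onnx

class FlattenNet(nn.Module):
    def forward(self, x):
        return x.flatten(1)            # the recommended form from above

torch.onnx.export(FlattenNet(), torch.rand(1, 3, 4, 4), "flatten_test.onnx")
m = onnx.load("flatten_test.onnx")
print([n.op_type for n in m.graph.node])   # check which operator(s) were emitted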
9. References
https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/index.html
https://github.com/onnx/onnx/blob/master/docs/IR.md