ONNX系列: ONNX模型结构解析

本文详细介绍了ONNX(OpenNeuralNetworkExchange)的背景、结构,特别是神经网络模型的计算图、权重存储和ONNX模型的组成部分。通过PythonAPI展示了如何访问和操作模型的不同层级,为模型部署和优化提供了实用指导。
摘要由CSDN通过智能技术生成

1. ONNX 背景        

ONNX 全称为 Open Neural Network Exchange,是微软提出并推广的一种机器学习模型的开放格式表示。ONNX定义了一组通用的算子集、一系列用于构建深度学习模型的模块以及一种通用的文件格式,使得人工智能开发人员能够将模型与各种框架、工具、运行时和编译器一起使用。ONNX可以理解为是 AI 算法框架与硬件平台之间的桥梁,AI算法研究人员可以使用任意的深度学习框架来设计并训练模型,训练完成后将模型转换成 ONNX 格式来进行存储,模型部署工程师可以针对 ONNX 这一中间格式来针对不同的硬件平台进行运行时设计和优化,从而实现AI模型设计和模型部署的解耦。

2. ONNX 结构分析

想要部署ONNX模型,我们首先需要了解 ONNX 模型的结构。神经网络模型是由 计算图 + 权重 组成的,计算图是一个有向无环的计算流程图,权重则是网络训练好的参数集合。关于 ONNX 更详尽的介绍,可以参考 ONNX 官方文档 onnx/docs at main · onnx/onnx · GitHub

ONNX 模型是利用 ProtoBuf 这一数据结构存储协议来将模型序列化到硬盘上的。一个存储到本地的 .onnx 模型可以以结构化的方式解析成下图所示的各个部分,其中比较重要的部分加粗表示。加粗部分上方是当前结构的名称,下方是当前结构的类型,各个Proto类型的定义可以参考onnx/onnx/onnx.proto at main · onnx/onnx · GitHub

2.1 查看onnx

  • 可以使用 Netron 来可视化查看 ONNX 模型
  • 可以使用 protoc 工具来解析 .onnx 模型文件。命令中 onnx.proto 是 ONNX 官方 repo 中的 Proto 定义。这条命令的含义是将super-resolution-10.onnx作为输入,按照 onnx.proto 定义从中提取 onnx.ModelProto 对象,并将结果重定向到 model.txt。

$ protoc --decode=onnx.ModelProto -I D:\Python\workspace\onnx_learn\onnx\onnx onnx.proto < D:\Python\workspace\onnx_learn\super-resolution-10.onnx > model.txt

2.2 使用 Python 来获取到 ONNX 模型结构

1) onnx model 结构

onnx model 是 ModelProto 类型,是 ONNX 模型最顶层的结构。其所包含的各个成员如下:

属性名

示例值

描述

ir_version

int64

模型的onnx IR版本。

opset_import

OperatorSetId

可用于模型的算子集标识符集合。一个onnx实现中必须包含这个集合中的所有算子,否则将拒接模型。

producer_name

string

生成这个onnx的生产者工具名称。

producer_version

string

这个生产者工具的版本。

domain

string

onnx模型的命名空间,用反向域名命名,和java一样。

model_version

int64

模型本身的版本。

doc_string

string

文档注释,可以是Markdown。

graph

Graph

模型计算图。

metadata_props

map<string,string>

元数据的键值对属性。

training_info

TrainingInfoProto[]

包含训练信息的可选扩展。

functions

FunctionProto[]

模型本地函数的可选列表。

使用Python API获取onnx model各成员

import onnx

model = onnx.load('super-resolution-10.onnx')

print(f"model.ir_version ---> {model.ir_version}")

print(f"model.opset_import ---> {model.opset_import}")

print(f"model.producer_name ---> {model.producer_name}")

print(f"model.producer_version ---> {model.producer_version}")

print(f"model.domain ---> {model.domain}")

print(f"model.model_version ---> {model.model_version}")

print(f"model.doc_string ---> {model.doc_string}")

print(f"model.metadata_props ---> {model.metadata_props}")

print(f"model.training_info ---> {model.training_info}")

print(f"model.functions ---> {model.functions}")

print(f"model.graph ---> {model.graph}")

2) model.graph 结构

model 中最重要的是 graph,类型是 GraphProto。 其所包含的各个成员如下:

属性名

示例值

描述

name

string

模型计算图的名字

node

Node[]

计算图中的算子集合(有向无环图的节点集),按照拓扑排序排列

initializer

Tensor[]

计算图中的initializer,是一个tensor列表,通常存放模型的权重,可以理解为一个常量池。

doc_string

string

文档注释

input

ValueInfo[]

模型计算图的输入tensor列表。

output

ValueInfo[]

模型计算图的输出tensor列表。

value_info

ValueInfo[]

模型计算图除输入输出外中间tensor列表,当使用shape_inference时,推理出来的shape存储到这里,即 Netron 中看到的中间tensor的维度。

metadata_props

map<string,string>

模型计算图的元数据(IR version >= 10)

使用Python API获取 model.graph 各成员

print("-------------------- model.graph.name --------------------")

print(model.graph.name)

print("-------------------- model.graph.node --------------------")

print(model.graph.node)

print("-------------------- model.graph.initializer --------------------")

print(model.graph.initializer)

print("-------------------- model.graph.doc_string --------------------")

print(model.graph.doc_string)

print("-------------------- model.graph.input --------------------")

print(model.graph.input)

print("-------------------- model.graph.output --------------------")

print(model.graph.output)

print("-------------------- model.graph.value_info --------------------")

print(model.graph.value_info)

print("-------------------- model.graph.metadata_props (IR version >= 10) --------------------")

print(model.graph.metadata_props)

3) model.graph.node 结构

model.graph中的 node 是一个节点集列表,其中的每个元素节点均为 NodeProto 类型,所包含的成员如下:

属性名

示例值

描述

name

string

节点的名字

input

string[]

节点的输入列表,相当于计算图的输入边集。

output

string[]

节点的输出列表,相当于计算图的输出边集。

op_type

string

节点的类型,表明该节点的计算逻辑。

domain

string

ONNX中定义的节点集的域。由于 ONNX 是支持第三方拓展内置的算子集的,这个域唯一的指明节点的op_type,类似Java的包管理一样,用域名倒置表示。

attribute

Attribute[]

节点的属性列表。例如Conv节点的kernel shape和padding等。

doc_string

string

文档注释。

overload

string

函数的唯一ID。(added in IR version 10)

metadata_props

map<string,string>

节点的元数据。(IR version >= 10)

使用Python API获取 model.graph 各成员

# 打印 node 各个属性值

print("----------------------- model.graph.node[0].name -----------------------")

print(model.graph.node[0].name)

print("----------------------- model.graph.node[0].input -----------------------")

print(model.graph.node[0].input)

print("----------------------- model.graph.node[0].output -----------------------")

print(model.graph.node[0].output)

print("----------------------- model.graph.node[0].op_type -----------------------")

print(model.graph.node[0].op_type)

print("----------------------- model.graph.node[0].domain -----------------------")

print(model.graph.node[0].domain)

print("----------------------- model.graph.node[0].attribute -----------------------")

print(model.graph.node[0].attribute)

print("----------------------- model.graph.node[0].doc_string -----------------------")

print(model.graph.node[0].doc_string)

print("----------------------- model.graph.node[0].overload -----------------------")

print(model.graph.node[0].overload)

print("----------------------- model.graph.node[0].metadata_props -----------------------")

print(model.graph.node[0].metadata_props)

4) model.graph.initializer 结构

model.graph.initializer 是一个tensor列表,其中的元素类型为TensorProto。Initializer通常保存模型的权重参数,一些输入默认值也可以保存在这里,可以将其理解为一个tensor常量池。initializer每个元素的成员如下:

属性名

示例值

描述

name

string

该tensor的名字

dims

int[]

该tensor的维度

data_type

int

该tensor的数据类型,不同的数值代表不同个的数据类型

raw_data

bytes

该tensor保存的具体数据,二进制形式

doc_string

string

文档注释

使用Python API获取 model.graph 各成员

print("----------------------- model.graph.initializer[0].name -----------------------")

print(model.graph.initializer[0].name)

print("----------------------- model.graph.initializer[0].dims -----------------------")

print(model.graph.initializer[0].dims)

print("----------------------- model.graph.initializer[0].data_type -----------------------")

print(model.graph.initializer[0].data_type)

print("----------------------- model.graph.initializer[0].raw_data -----------------------")

# 二进制表示,打印出来可能会很长

print(model.graph.initializer[0].raw_data)

5)model.graph.input & output & value_info 结构

graph 中的 input、output 和 value_info 均为一个列表,可以使用index进行索引。Input为计算图的所有输入,output是计算图的所有输出,value_info则为计算图中所有中间计算结果tensor的信息。当使用

onnx.shape_inference.infer_shapes()推理所有中间tensor的维度时,这些信息均会保存在value_info中。input、output 和 value_info的每个元素类型为ValueInfoProto,其包含的成员如下

属性名

示例值

描述

name

string

当前值的名字

type

TypeProto

当前值的类型,这其中包含当前值的数据类型和维度

使用Python API获取 model.graph 各成员

# 以第一个input为例

print("----------------------- model.graph.input[0].name -----------------------")

print(model.graph.input[0].name)

print("----------------------- model.graph.input[0].type -----------------------")

print(model.graph.input[0].type)

3. 总结

本文重点解析了 ONNX 模型结构,并演示了如何使用Python定位到ONNX模型各个层面的元素。在得到不同元素之后,我们可以对ONNX模型进行适当的修改,使其更加适配我们的后端运行时,进一步提高推理性能。我们在之后的文章会介绍如何使用 ONNX 官方的 API 来修改ONNX模型。

作者:高通工程师,阮慧源(Huiyuan Ruan)

  • 20
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值