We can use the ONNX model that ships with TVM (tests/micro/testdata/mnist/mnist-8.onnx) to step through TVM's model compilation flow. The code is as follows:
import onnx
import numpy as np
import tvm
from tvm import te
import tvm.relay as relay
from PIL import Image
onnx_model = onnx.load('mnist-8.onnx')
image_path = '0.png'
img = Image.open(image_path).resize((28, 28))
img = img.convert('L')
img = np.array(img, dtype=np.float32)
x = img.reshape((1, 1, 28, 28))
target = "llvm"
input_name = "Input3"
shape_dict = {input_name: x.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
with tvm.transform.PassContext(opt_level=1):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target)
######################################################################
# Execute on TVM
# ---------------------------------------------
dtype = "float32"
tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
print(np.argmax(tvm_output))
Here relay.frontend.from_onnx(onnx_model, shape_dict) converts the ONNX model into a graph IR that TVM understands. To follow this process, we need a basic understanding of how an ONNX model is defined.
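To get a feel for the result, we can print the conversion output right after the from_onnx call in the script above (a small sketch; the exact variable and parameter names depend on the model):

print(mod)           # the Relay IR of the whole model, in Relay text form
print(list(params))  # names of the weight tensors extracted from the ONNX file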
1. A brief introduction to the ONNX model file
The data structures of an ONNX model are defined in the onnx.proto file (https://github.com/onnx/onnx/blob/main/onnx/onnx.proto). The main message types are:
ModelProto
GraphProto
NodeProto
ValueInfoProto
TensorProto
AttributeProto
Loading an ONNX model yields a ModelProto, which contains some version information, producer information, and a GraphProto. The GraphProto in turn contains four repeated arrays: node (NodeProto), input (ValueInfoProto), output (ValueInfoProto), and initializer (TensorProto). node holds all the compute nodes of the model, input holds the model's input nodes, output holds all of the model's output nodes, and initializer holds all of the model's weight parameters.
Each ONNX compute node has two string arrays, input and output; the network structure of a deep learning model is built from the producer/consumer relationships that these names establish.
Note that the input array of GraphProto contains not only what we usually think of as the model input (e.g. the image), but also all of the model's weights. For example, the weight tensor W of a Conv layer is stored in initializer, and there is a correspondingly named entry in input. The underlying idea is to treat the weights as inputs of the model as well, and to initialize those inputs from the weight data stored in initializer, i.e. an assignment.
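These arrays are easy to inspect directly with the onnx Python package (a minimal sketch, assuming mnist-8.onnx sits in the current directory):

import onnx

model = onnx.load('mnist-8.onnx')            # a ModelProto
graph = model.graph                          # the GraphProto it contains
print([i.name for i in graph.input])         # the real input plus one entry per weight
print([t.name for t in graph.initializer])   # the weight tensors themselves
print([n.op_type for n in graph.node])       # compute nodes, e.g. Conv, MaxPool, ...
print([o.name for o in graph.output])        # the model outputs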
Both initializer and input therefore mention the network weights, so what is the difference between them? initializer is an array of TensorProto:
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool
    FLOAT16 = 10;
    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;   // complex with float32 real and imaginary components
    COMPLEX128 = 15;  // complex with float64 real and imaginary components
    BFLOAT16 = 16;
  }

  repeated int64 dims = 1;
  optional int32 data_type = 2;

  message Segment {
    optional int64 begin = 1;
    optional int64 end = 2;
  }
  optional Segment segment = 3;

  repeated float float_data = 4 [packed = true];
  repeated int32 int32_data = 5 [packed = true];
  repeated bytes string_data = 6;
  repeated int64 int64_data = 7 [packed = true];
  optional string name = 8; // namespace Value
  optional string doc_string = 12;
  optional bytes raw_data = 9;
  repeated StringStringEntryProto external_data = 13;

  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }
  optional DataLocation data_location = 14;

  repeated double double_data = 10 [packed = true];
  repeated uint64 uint64_data = 11 [packed = true];
}
whereas the input array holds ValueInfoProto entries:
message ValueInfoProto {
  optional string name = 1; // namespace Value
  optional TypeProto type = 2;
  optional string doc_string = 3;
}
Comparing the two messages suggests that initializer records the complete weight data, while an input entry records only descriptive information such as the weight's name, type, and shape.
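This is easy to verify: onnx.numpy_helper recovers the full weight data from an initializer entry, while a ValueInfoProto only yields a name plus type and shape information (a sketch, again using mnist-8.onnx):

import onnx
from onnx import numpy_helper

model = onnx.load('mnist-8.onnx')
init = model.graph.initializer[0]
w = numpy_helper.to_array(init)   # the complete weight data lives here
print(init.name, init.dims, w.shape, w.dtype)

info = model.graph.input[0]       # only a description: name, element type, shape
tt = info.type.tensor_type
print(info.name, tt.elem_type, [d.dim_value for d in tt.shape.dim])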
Finally, each compute node also contains an AttributeProto array describing the node's attributes; for a Conv node (i.e. a convolution layer), these include group, pads, strides, and so on.
ONNX treats each layer of a network, i.e. each operator, as a node and uses these nodes to build a Graph, which represents the network. The Graph is then combined with producer information, version information, and so on to form a Model, the final ONNX model file.
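The onnx.helper module mirrors this Node -> Graph -> Model construction, so a toy model can be built by hand (a minimal sketch with made-up tensor names):

import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
node = helper.make_node('Relu', inputs=['x'], outputs=['y'])  # one compute node
graph = helper.make_graph([node], 'tiny_graph', [x], [y])     # nodes + inputs + outputs
model = helper.make_model(graph)                              # adds producer/version info
onnx.checker.check_model(model)                               # the same check from_onnx runs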
2. The from_onnx flow
relay.frontend.from_onnx reads the ONNX model, parses its initializer, input, node, and output arrays according to the structure described above, converts each operator into TVM Relay operators and expressions, and finally produces the Relay IRModule for the whole model. from_onnx is defined in python/tvm/relay/frontend/onnx.py:
def from_onnx(
    model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
    global ONNX_DEFAULT_CONFIGS
    if convert_config is not None:
        ONNX_DEFAULT_CONFIGS.update(convert_config)

    try:
        import onnx

        if hasattr(onnx.checker, "check_model"):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except Exception as e:  # pylint: disable=c-extension-no-member, broad-except
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass

    g = GraphProto(shape, dtype, freeze_params)
    graph = model.graph

    try:
        opset_in_model = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            ""
            f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
            f"That might cause model conversion errors."
        )

    # Use the graph proto as a scope so that ops can access other nodes if needed.
    with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params
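The opset handling above defaults to the opset version recorded in the model itself, which can be inspected directly (a small sketch):

import onnx

model = onnx.load('mnist-8.onnx')
print(model.opset_import)   # e.g. one entry whose version field becomes opset_in_model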
For our purposes, we only need to pay attention to the following lines:
...
g = GraphProto(shape, dtype, freeze_params)
graph = model.graph
...
with g:
    mod, params = g.from_onnx(graph, opset)
return mod, params
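Once mod and params are in hand, they feed the normal Relay pipeline. For example, instead of the create_executor call in the opening script, one could compile ahead of time (a sketch under the same llvm target assumption):

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)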
This creates an instance g of TVM's own GraphProto class and passes the ONNX model's graph to its from_onnx method, which returns the model's TVM intermediate representation mod (of type tvm.IRModule) and its parameters params. GraphProto.from_onnx proceeds as follows:
def from_onnx(self, graph, opset, get_output_expr=False):
    """Construct Relay expression from ONNX graph.

    Onnx graph is a python protobuf object.
    The companion parameters will be handled automatically.
    However, the input names from onnx graph is vague, mixing inputs and
    network weights/bias such as "1", "2"...
    For convenience, we rename the `real` input names to "input_0",
    "input_1"... And renaming parameters to "param_0", "param_1"...

    Parameters
    ----------
    graph : onnx protobuf object
        The loaded onnx graph

    opset : opset version

    get_output_expr: bool
        If set to true, this conversion will return each output expression rather
        than a packaged module. This can be useful when converting subgraphs to
        relay.

    Returns
    -------
    mod : tvm.IRModule
        The returned relay module

    params : dict
        A dict of name: tvm.nd.array pairs, used as pretrained weights
    """
    self.opset = opset
    # parse network inputs to relay, aka parameters
    # graph.initializer stores the weight parameters of every node in the model
    for init_tensor in graph.initializer:
        # Note that init_tensor.name is not the name of a network operator but the
        # name of a weight tensor. For example, in a ResNet-18 a convolution node
        # may be named resnetv15_stage1_conv2_fwd while its weight W is named 186
        # and its bias B is named 188; graph.initializer then contains two
        # corresponding entries whose names are 186 and 188.
        if not init_tensor.name.strip():
            raise ValueError("Tensor's name is required.")
        # Convert the weight tensor into a TVM NDArray
        array = self._parse_array(init_tensor)
        if self._freeze_params:
            # If freeze_params is set, all parameters are treated as constants
            # (so they can be folded away during optimization)
            self._nodes[init_tensor.name] = _expr.const(array)
        else:
            # Record the parsed parameter in the parameter table
            self._params[init_tensor.name] = array
            # Add a variable for this parameter to the node table
            self._nodes[init_tensor.name] = new_var(
                init_tensor.name,
                shape=self._params[init_tensor.name].shape,
                dtype=self._params[init_tensor.name].dtype,
            )
    # The input array of GraphProto contains not only the model inputs but also the
    # weights of the individual nodes: both the external data (e.g. the image) and
    # the network's own weight parameters are treated as inputs
    for i in graph.input:
        # from onnx v0.2, GraphProto.input has type ValueInfoProto,
        # and the name is 'i.name'
        # Get the name, shape, dtype, etc. of this input
        i_name, i_shape, d_type, i_shape_name = get_info(i)
        # If it is a node parameter, the graph.initializer pass above has already
        # added a corresponding entry to the parameter table
        if i_name in self._params:
            # i is a param instead of input
            self._num_param += 1
            self._params[i_name] = self._params.pop(i_name)
            self._nodes[i_name] = new_var(
                i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
            )
        elif i_name in self._nodes:
            continue
        # Otherwise it is a real model input
        else:
            self._num_input += 1
            self._input_names.append(i_name)
            # self._shape is the input-shape dict the user passed to from_onnx
            if i_name in self._shape:
                i_shape = self._shape[i_name]
            else:
                # The model input shape has unknown dimensions
                if "?" in str(i_shape):
                    warning_msg = (
                        "Input %s has unknown dimension shapes: %s. "
                        "Specifying static values may improve performance"
                        % (i_name, str(i_shape_name))
                    )
                    warnings.warn(warning_msg)
            if isinstance(self._dtype, dict):
                dtype = self._dtype[i_name] if i_name in self._dtype else d_type
            else:
                dtype = d_type
            # Add the input node to the node table
            self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
        # Record the model's input node
        self._inputs[i_name] = self._nodes[i_name]
    # Only check user inputs in the outer-most graph scope.
    if self._old_manager is None:
        assert all(
            [name in self._input_names for name in self._shape.keys()]
        ), "User specified the shape for inputs that weren't found in the graph: " + str(
            self._shape
        )
    # get list of unsupported ops
    # Get the mapping table from ONNX operators to Relay converters
    convert_map = _get_convert_map(opset)
    unsupported_ops = set()
    for node in graph.node:
        op_name = node.op_type
        if (
            op_name not in convert_map
            and op_name != "Constant"
            and op_name not in _identity_list
        ):
            # An operator that is neither in the convert map nor in _identity_list
            # is considered unsupported by TVM
            unsupported_ops.add(op_name)
    if unsupported_ops:
        msg = "The following operators are not supported for frontend ONNX: "
        msg += ", ".join(unsupported_ops)
        # Conversion fails if any operator is unsupported
        raise tvm.error.OpNotImplemented(msg)
    # construct nodes, nodes are stored as directed acyclic graph
    # Process the node array of the ONNX model. The proto type of an ONNX node is
    # NodeProto (see https://gitee.com/mirrors/ONNX/blob/main/onnx/onnx.proto);
    # its input/output entries are strings.
    for node in graph.node:
        # Get the operator type and attributes
        op_name = node.op_type
        attr = self._parse_attr(node.attribute)
        # Create and populate input list.
        inputs = onnx_input()
        # Look up every input of the current ONNX node by name
        for i in node.input:
            if i != "":
                inputs.append(self._nodes[self._renames.get(i, i)])
            else:
                # An empty name means this (optional) input is not used
                inputs.append(None)
        i_name = self._parse_value_proto(node)
        # Get the ONNX node's outputs: a list of output names (strings)
        node_output = self._fix_outputs(op_name, node.output)
        # Record extra node information in the attributes
        attr["tvm_custom"] = {}
        attr["tvm_custom"]["name"] = i_name
        attr["tvm_custom"]["num_outputs"] = len(node_output)
        # Convert the ONNX operator node into its TVM representation. For example,
        # a convolution node in an ONNX model is converted into an op of type
        # tvm.relay.expr.Call whose content is:
        #   free_var %data: Tensor[(1, 3, 224, 224), float32];
        #   free_var %v174: Tensor[(64, 3, 7, 7), float32];
        #   %0 = nn.conv2d(%data, %v174, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]);
        #   free_var %v176: Tensor[(64), float32];
        #   nn.bias_add(%0, %v176)
        op = self._convert_operator(op_name, inputs, attr, opset)
        # Determine the number of outputs of the operator
        if not isinstance(op, _expr.TupleWrapper):
            outputs_num = 1
        else:
            outputs_num = len(op)

        if outputs_num == 1:
            op = fold_constant(op)
        else:
            op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

        # If the operator has more than one output
        if outputs_num > 1:
            # ONNX supports optional outputs for some nodes.
            # This block searches for missing outputs in the ONNX graph
            # and removes any unneeded ops
            # Determine the valid outputs of the node: when a node has several
            # outputs and one of them is unused, its name is the empty string
            valid_outputs = [False] * outputs_num
            for i, output in enumerate(node_output):
                if output != "":
                    valid_outputs[i] = True
            # If we have outputs ONNX isn't expecting, we need to drop them
            # i.e. if some of the node's outputs are invalid
            if not all(valid_outputs):
                # op is the TVM representation converted from the ONNX node
                tup = op.astuple()
                # TupleWrapper can also wrap ops with TupleType outputs
                # Pick out the parts of the TVM expression that correspond to the
                # valid outputs and build the actual outputs from them
                if isinstance(tup, _expr.Tuple):
                    # For tuples, we extract the fields instead of using GetTupleItem
                    outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
                else:
                    # For call nodes, we need to GetTupleItem
                    outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
                # Create the new op with valid outputs
                if len(outputs) == 1:
                    op = outputs[0]
                # If several outputs remain after invalid ones were dropped
                elif len(outputs) != outputs_num:
                    op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
                # Drop invalid outputs for the onnx node
                outputs_num = len(outputs)
                # Update the ONNX node's output name list (strings)
                node_output = [output for output in node_output if output != ""]
        assert (
            len(node_output) == outputs_num
        ), "Number of output mismatch {} vs {} in {}.".format(
            len(node_output), outputs_num, op_name
        )
        # Add the outputs to the node table; the value of each entry is the node's
        # TVM representation
        if outputs_num == 1:
            self._nodes[node_output[0]] = op
        else:
            for k, i in zip(list(node_output), range(len(node_output))):
                self._nodes[k] = op[i]
    # now return the outputs
    # graph.output holds the output names of the ONNX model; look up the
    # corresponding values (the TVM representation of the output nodes) in the
    # node table
    outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
    # If requested, directly return the converted expressions.
    # get_output_expr is a from_onnx argument indicating that the caller only
    # wants the output expression
    if get_output_expr:
        return outputs
    ## Maintain the order of inputs and parameters from the ONNX graph, but only include
    ## those parameters that are needed to execute the relay graph
    # Collect all free variables in the TVM representation of the model outputs
    free_vars = analysis.free_vars(outputs)
    nodes = {v: k for k, v in self._nodes.items()}
    free_vars = [nodes[var] for var in free_vars]
    for i_name in self._params:
        # If a weight parameter appears among these free variables
        if i_name in free_vars and i_name not in self._inputs:
            # add it to the model's inputs
            self._inputs[i_name] = self._nodes[i_name]
    # Create a function from our output expression and all input variables.
    # Build a Function from the model inputs (including required weights) and the
    # output expression
    func = _function.Function([v for k, v in self._inputs.items()], outputs)
    # Return the IRModule and all weight parameters
    return IRModule.from_expr(func), self._params
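The closing free_vars step is worth seeing in isolation: the unbound variables of the output expression become the parameters of the generated Function (a minimal sketch with made-up shapes, not taken from the model above):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3))
w = relay.var("w", shape=(4, 3))
y = relay.nn.dense(x, w)
print(relay.analysis.free_vars(y))     # [Var(x, ...), Var(w, ...)]
func = relay.Function(relay.analysis.free_vars(y), y)
print(tvm.IRModule.from_expr(func))    # an IRModule with a single "main" function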
References
【从零开始学TVM】三,基于ONNX模型结构了解TVM的前端 (Learning TVM from Scratch, Part 3: Understanding TVM's frontend via the ONNX model structure)
https://zhuanlan.zhihu.com/p/365800737