【TVM源码学习笔记】2. from_onnx流程分析

直接使用tvm自带的onnx模型(tests/micro/testdata/mnist/mnist-8.onnx)来debug下TVM的模型编译流程,代码如下

import onnx
import numpy as np
import tvm
from tvm import te
import tvm.relay as relay
from PIL import Image

onnx_model = onnx.load('mnist-8.onnx')

image_path = '0.png'
img = Image.open(image_path).resize((28, 28))
img = img.convert('L')
img = np.array(img, dtype=np.float32)
x = img.reshape((1, 1, 28, 28))

target = "llvm"

input_name = "Input3"
shape_dict = {input_name: x.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

with tvm.transform.PassContext(opt_level=1):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target)

######################################################################
# Execute on TVM
# ---------------------------------------------
dtype = "float32"
tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy()

print(np.argmax(tvm_output))

这里relay.frontend.from_onnx(onnx_model, shape_dict)是将onnx模型转换为TVM可以识别的Graph IR。要理解这一流程,需要对onnx模型定义有基础的了解。

1.onnx模型文件简介

onnx模型的数据定义参见(https://github.com/onnx/onnx/blob/main/onnx/onnx.proto)onnx.proto文件。onnx模型的数据类型有如下几种:

        ModelProto

        GraphProto

        NodeProto

        ValueInfoProto

        TensorProto

        AttributeProto

        加载了一个onnx模型之后获得的就是一个ModelProto,它包含了一些版本信息,生产者信息和一个GraphProto。 在GraphProto里面又包含了四个repeated数组,它们分别是node(NodeProto类型),input(ValueInfoProto类型),output(ValueInfoProto类型)和initializer(TensorProto类型)。其中node中存放了模型中所有的计算节点,input存放了模型的输入节点,output存放了模型中所有的输出节点,initializer存放了模型的所有权重参数。

        ONNX的每个计算节点都会有input和output两个数组,这两个数组是string类型,通过input和output的指向关系来构建出一个深度学习模型的网络结构。

        这里要注意一下, GraphProto中的input数组不仅包含我们一般理解中的图片输入的那个节点,还包含了模型中所有的权重。例如,Conv层里面的W权重实体是保存在initializer中的,那么相应的会有一个同名的输入在input中。其背后的逻辑应该是把权重也看成模型的输入,并通过initializer中的权重实体来对这个输入做初始化,即一个赋值的过程。

initializer和input中都有网络权重,那么它们有什么区别呢?initializer是TensorProto类型数组

message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components
    BFLOAT16 = 16;
  }

  repeated int64 dims = 1;
  optional int32 data_type = 2;
  
  message Segment {
    optional int64 begin = 1;
    optional int64 end = 2;
  }
  optional Segment segment = 3;
  repeated float float_data = 4 [packed = true];
  repeated int32 int32_data = 5 [packed = true];
  repeated bytes string_data = 6;
  repeated int64 int64_data = 7 [packed = true];
  optional string name = 8; // namespace Value
  optional string doc_string = 12;
  optional bytes raw_data = 9;
  repeated StringStringEntryProto external_data = 13;

  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }

  optional DataLocation data_location = 14;
  repeated double double_data = 10 [packed = true];
  repeated uint64 uint64_data = 11 [packed = true];
}

而input数组是ValueInfoProto类型:

message ValueInfoProto {
  optional string name = 1;     // namespace Value
  optional TypeProto type = 2;
  optional string doc_string = 3;
}

猜测initializer里面是记录的权重完整数据,而input数组中只是记录了权重的名字和类型等描述信息。

        最后,每个计算节点中还包含了一个AttributeProto数组,用来描述该节点的属性,比如Conv节点或者说卷积层的属性包含group,pad,strides等等。

        ONNX是把一个网络的每一层或者说一个算子当成节点node,使用这些Node去构建一个Graph,即一个网络。将Graph和其它的生产者信息,版本信息等合并在一起生成一个Model,也即是最终的ONNX模型文件。

2. from_onnx流程

relay.frontend.from_onnx就是读入onnx模型,按照前述onnx模型数据结构,解析模型的initializer、input、nodes和output,将算子转换为tvm relay ir算子和表达式,最终得到整个模型的tvm relay IRModule。from_onnx定义在python/tvm/relay/frontend/onnx.py:

def from_onnx(
    model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
    global ONNX_DEFAULT_CONFIGS
    if convert_config is not None:
        ONNX_DEFAULT_CONFIGS.update(convert_config)

    try:
        import onnx

        if hasattr(onnx.checker, "check_model"):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except Exception as e:  # pylint: disable=c-extension-no-member, broad-except
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype, freeze_params)
    graph = model.graph

    try:
        opset_in_model = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            ""
            f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
            f"That might cause model conversion errors."
        )

    # Use the graph proto as a scope so that ops can access other nodes if needed.
    with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

我们可以只用关注下面几行代码

...

g = GraphProto(shape, dtype, freeze_params)
graph = model.graph

...

with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

这里生成一个TVM的GraphProto实例g, 然后将传入的onnx模型graph传入GraphProto的from_onnx方法,得到模型的TVM中间表示mod(tvm.IRModule类型)和参数params。GraphProto的from_onnx的流程如下:

def from_onnx(self, graph, opset, get_output_expr=False):
	"""Construct Relay expression from ONNX graph.

	Onnx graph is a python protobuf object.
	The companion parameters will be handled automatically.
	However, the input names from onnx graph is vague, mixing inputs and
	network weights/bias such as "1", "2"...
	For convenience, we rename the `real` input names to "input_0",
	"input_1"... And renaming parameters to "param_0", "param_1"...

	Parameters
	----------
	graph : onnx protobuf object
		The loaded onnx graph

	opset : opset version

	get_output_expr: bool
		If set to true, this conversion will return each output expression rather
		than a packaged module. This can be useful when converting subgraphs to
		relay.

	Returns
	-------
	mod : tvm.IRModule
		The returned relay module

	params : dict
		A dict of name: tvm.nd.array pairs, used as pretrained weights
	"""

	self.opset = opset
	# parse network inputs to relay, aka parameters
	# onnx的initializer存放了模型的每个节点的权重参数
	for init_tensor in graph.initializer:
		# 注意这里的init_tensor.name并不是网络算子的名字,而是权重的命令.例如一个resnet18网络中有
		# 卷积节点名字为resnetv15_stage1_conv2_fwd,这个节点的权重参数W名字为186,偏置参数B名字为188
		# graph.initializer里面就会有两个对应的节点,name分别为186和188
		if not init_tensor.name.strip():
			raise ValueError("Tensor's name is required.")
		# 将权重矩阵转换为tvm的nD-array
		array = self._parse_array(init_tensor)
		if self._freeze_params:
			# 如果设置了feeze_params参数,则所有的参数都认为是常量(针对不定参数模型的优化?)
			self._nodes[init_tensor.name] = _expr.const(array)
		else:
			# 将解析的参数记录到参数表中
			self._params[init_tensor.name] = array
			# 在node表中加入该参数
			self._nodes[init_tensor.name] = new_var(
				init_tensor.name,
				shape=self._params[init_tensor.name].shape,
				dtype=self._params[init_tensor.name].dtype,
			)
	# GraphProto中的input数组不仅包含模型的输入,还包含了模型中各节点的权重, 
	# 也就是将外部的图片数据输入和网络自身的权重参数都当作输入
	for i in graph.input:
		# from onnx v0.2, GraphProto.input has type ValueInfoProto,
		#  and the name is 'i.name'
		# 获取参数的name, shape, type等信息
		i_name, i_shape, d_type, i_shape_name = get_info(i)
		# 如果是节点参数,则在前面graph.initializer的处理中已经在参数表中添加了对应的节点
		if i_name in self._params:
			# i is a param instead of input
			self._num_param += 1
			self._params[i_name] = self._params.pop(i_name)
			self._nodes[i_name] = new_var(
				i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
			)
		elif i_name in self._nodes:
			continue
		# 如果是模型的输入
		else:
			self._num_input += 1
			self._input_names.append(i_name)
			# self._shape是用户在调用from_onnx的时候传入的模型输入shape参数
			if i_name in self._shape:
				i_shape = self._shape[i_name]
			else:
				# 模型的输入shape有不定项
				if "?" in str(i_shape):
					warning_msg = (
						"Input %s has unknown dimension shapes: %s. "
						"Specifying static values may improve performance"
						% (i_name, str(i_shape_name))
					)
					warnings.warn(warning_msg)
			if isinstance(self._dtype, dict):
				dtype = self._dtype[i_name] if i_name in self._dtype else d_type
			else:
				dtype = d_type
			# 在nodes表中加入输入节点
			self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
		#记录模型的输入node
		self._inputs[i_name] = self._nodes[i_name]
	# Only check user inputs in the outer-most graph scope.
	if self._old_manager is None:
		assert all(
			[name in self._input_names for name in self._shape.keys()]
		), "User specified the shape for inputs that weren't found in the graph: " + str(
			self._shape
		)
	# get list of unsupported ops
	# 获取onnx算子和tvm算子的映射表
	convert_map = _get_convert_map(opset)
	unsupported_ops = set()
	for node in graph.node:
		op_name = node.op_type
		if (
			op_name not in convert_map
			and op_name != "Constant"
			and op_name not in _identity_list
		):
			# 如果算子不在映射表中,也不在_identity_list表中,则认为当前算子是TVM不支持的
			unsupported_ops.add(op_name)
	if unsupported_ops:
		msg = "The following operators are not supported for frontend ONNX: "
		msg += ", ".join(unsupported_ops)
		# 如果有不支持的算子,则转换失败
		raise tvm.error.OpNotImplemented(msg)
	# construct nodes, nodes are stored as directed acyclic graph
	# 处理onnx模型文件中的node数据, onnx node的proto类型为NodeProto,参见https://gitee.com/mirrors/ONNX/blob/main/onnx/onnx.proto
	# 其中input/output都是string类型
	for node in graph.node:
		# 获取算子类型和属性
		op_name = node.op_type
		attr = self._parse_attr(node.attribute)
		# Create and populate input list.
		# 创建一个(算子)输入实例
		inputs = onnx_input()
		# 获取当前(onnx)节点的所有输入(name)
		for i in node.input:
			if i != "":
				inputs.append(self._nodes[self._renames.get(i, i)])
			else:
				# 有些输入没使用?
				inputs.append(None)
		i_name = self._parse_value_proto(node)
		# 获取onnx节点的输出,为string类型,是输出的name
		node_output = self._fix_outputs(op_name, node.output)
		# 记录onnx节点的属性
		attr["tvm_custom"] = {}
		attr["tvm_custom"]["name"] = i_name
		attr["tvm_custom"]["num_outputs"] = len(node_output)
		"""
		将onnx 算子节点转换为对应的tvm表示. 
		例如onnx模型中一个卷积算子节点,返回的op类型为tvm.relay.expr.Call, 数据为:
		free_var %data: Tensor[(1, 3, 224, 224), float32];
		free_var %v174: Tensor[(64, 3, 7, 7), float32];
		%0 = nn.conv2d(%data, %v174, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]);
		free_var %v176: Tensor[(64), float32];
		nn.bias_add(%0, %v176)            
		"""            
		op = self._convert_operator(op_name, inputs, attr, opset)
		# 获取算子输出个数
		if not isinstance(op, _expr.TupleWrapper):
			outputs_num = 1
		else:
			outputs_num = len(op)

		if outputs_num == 1:
			op = fold_constant(op)
		else:
			op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
		# 如果算子有多个输出
		if outputs_num > 1:
			# ONNX supports optional outputs for some nodes.
			# This block searches for missing outputs in the ONNX graph
			# and removes any unneeded ops
			# 获取节点的有效输出.这里认为当算子有多个输出并且某个输出没有被使用时,对应的name为空字符串
			valid_outputs = [False] * outputs_num
			for i, output in enumerate(node_output):
				if output != "":
					valid_outputs[i] = True
			# If we have outputs ONNX isn't expecting, we need to drop them
			# 如果节点有输出是无效
			if not all(valid_outputs):
				# op为onnx转换后的tvm表示
				tup = op.astuple()
				# TupleWrapper can also wrap ops with TupleType outputs
				# 从tvm表达式中将有效输出对应的部分挑出来,构造成实际的输出
				if isinstance(tup, _expr.Tuple):
					# For tuples, we extract the fields instead of using GetTupleItem                        
					outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
				else:
					# For call nodes, we need to GetTupleItem
					outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
				# Create the new op with valid outputs
				if len(outputs) == 1:
					op = outputs[0]
				# 如果有多个输出并且有无效输出被剔除 
				elif len(outputs) != outputs_num:
					op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
				# Drop invalid outputs for the onnx node
				outputs_num = len(outputs)
				# 更新onnx 节点的输出表, string类型
				node_output = [output for output in node_output if output != ""]
		assert (
			len(node_output) == outputs_num
		), "Number of output mismatch {} vs {} in {}.".format(
			len(node_output), outputs_num, op_name
		)
		# 将输出加入节点表, 节点的值为节点的tvm表示
		if outputs_num == 1:
			self._nodes[node_output[0]] = op
		else:
			for k, i in zip(list(node_output), range(len(node_output))):
				self._nodes[k] = op[i]

	# now return the outputs
	# graph.output为onnx模型的输出name, 这里从tvm nodes表中得到对应的值(输出节点的tvm表示)
	outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
	outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
	# If requested, directly return the converted expressions.
	# get_output_expr为调用from_onnx传入的参数,表示当前调用仅仅只是为了获取到输出表达式
	if get_output_expr:
		return outputs
	## Maintain the order of inputs and parameters from the ONNX graph, but only include
	## those parameters that are needed to execute the relay graph
	# 获取模型输出的tvm表示中的所有的free_var
	free_vars = analysis.free_vars(outputs)
	nodes = {v: k for k, v in self._nodes.items()}
	free_vars = [nodes[var] for var in free_vars]
	for i_name in self._params:
		# 如果有权重参数在这些free_var中
		if i_name in free_vars and i_name not in self._inputs:
			# 将这些权重加入到模型的输入中
			self._inputs[i_name] = self._nodes[i_name]
	# Create a function from our output expression and all input variables.
	# 由模型输入, 输出表达式依赖的权重和输出表达式生成一个function
	func = _function.Function([v for k, v in self._inputs.items()], outputs)
	# 返回表达式和所有权重
	return IRModule.from_expr(func), self._params

参考

【从零开始学TVM】三,基于ONNX模型结构了解TVM的前端
https://zhuanlan.zhihu.com/p/365800737

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
TVM的te(Tensor Expression)中,`tvm.te.if_then_else` 函数用于实现条件语句,根据一个布尔条件选择执行不同的计算逻辑。 `if_then_else` 函数的基本语法如下: ```python tvm.te.if_then_else(condition, then_expr, else_expr) ``` 参数说明: - `condition`:布尔条件表达式,用于决定选择哪个分支的计算逻辑。 - `then_expr`:当条件为真时执行的计算表达式。 - `else_expr`:当条件为假时执行的计算表达式。 下面是一个示例代码,演示如何使用 `if_then_else` 函数: ```python import tvm from tvm import te def if_then_else_example(): # 输入张量形状 shape = (4, ) # 创建输入和输出张量 input_tensor = te.placeholder(shape, name='input_tensor', dtype='float32') output_tensor = te.placeholder(shape, name='output_tensor', dtype='float32') # 定义计算 def compute(i): # 根据输入张量的值判断条件 condition = input_tensor[i] > 0 # 根据条件选择执行计算逻辑 then_expr = input_tensor[i] * 2 else_expr = input_tensor[i] / 2 # 使用 if_then_else 函数实现条件选择 return tvm.te.if_then_else(condition, then_expr, else_expr) # 创建计算描述 output = te.compute(shape, compute, name='output') return output.op.body[0] # 创建一个范围上下文 with tvm.target.Target('llvm'): # 构造计算图 stmt = if_then_else_example() # 打印生成的计算图 print(stmt) ``` 在上述示例中,我们定义了一个 `if_then_else_example()` 函数,创建了输入张量和输出张量。然后在 `compute()` 中,根据输入张量的值判断条件,并使用 `if_then_else` 函数实现条件选择:当输入张量大于0时,执行乘以2的计算逻辑;否则,执行除以2的计算逻辑。最后通过 `te.compute()` 创建计算描述,并返回计算图的第一个操作节点。 希望这能解答您的疑问!如果您还有其他问题,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值