【TVM源码学习笔记】2 模型导入from_onnx

在前文模型加载时,使用relay.frontend.from_onnx(onnx_model, shape_dict)是将onnx模型转换为TVM可以识别的Graph IR。要理解这一流程,需要对onnx模型定义有基础的了解。

1.onnx模型文件简介

onnx模型的数据定义参见(onnx/onnx.proto at main · onnx/onnx · GitHub)onnx.proto文件。onnx模型的数据类型有如下几种:

        ModelProto:加载了一个onnx模型之后获得的就是一个ModelProto,它包含了一些版本信息,生产者信息和一个GraphProto。

        GraphProto:在GraphProto里面又包含了四个repeated数组,它们分别是node(NodeProto类型),input(ValueInfoProto类型),output(ValueInfoProto类型)和initializer(TensorProto类型);

        NodeProto:网络节点数据类型。GraphProto使用一个NodeProto数组记录网络的所有节点。每个节点都会有两个string类型的数组:input和output,表示当前节点的输入数据的源节点和输出数据的目的节点。通过input和output的指向关系来构建出一个深度学习模型的网络结构

        ValueInfoProto:GraphProto中有两个ValueInfoProto类型的数组:input和output,分别存放网络的输入输出;

        TensorProto:张量类型。GraphProto中定义了一个该类型的数组:initializer,用于存放网络的常量输入,也就是模型的权重参数。数组中的所有张量必须是有名字的。并且这些张量名字也出现在前述的input数组中。

        AttributeProto:网络节点(NodeProto类型)的属性类型,用来描述该节点的属性信息,比如Conv节点或者说卷积层的属性包含group,pad,strides等等。

        这里要注意一下, GraphProto中的input数组不仅包含我们一般理解中的图片输入的那个节点,还包含了模型中所有的权重。例如,Conv层里面的W权重实体是保存在initializer中的,那么相应的会有一个同名的元素在input中。其背后的逻辑应该是把权重也看成模型的输入,并通过initializer中的权重实体来对这个输入做初始化。 

initializer和input中都有网络权重,那么它们有什么区别呢?initializer是TensorProto类型数组,记录了数据张量的名字、类型、shape等信息。

message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components
    BFLOAT16 = 16;
  }

  repeated int64 dims = 1;
  optional int32 data_type = 2;
  
  message Segment {
    optional int64 begin = 1;
    optional int64 end = 2;
  }
  optional Segment segment = 3;
  repeated float float_data = 4 [packed = true];
  repeated int32 int32_data = 5 [packed = true];
  repeated bytes string_data = 6;
  repeated int64 int64_data = 7 [packed = true];
  optional string name = 8; // namespace Value
  optional string doc_string = 12;
  optional bytes raw_data = 9;
  repeated StringStringEntryProto external_data = 13;

  enum DataLocation {
    DEFAULT = 0;
    EXTERNAL = 1;
  }

  optional DataLocation data_location = 14;
  repeated double double_data = 10 [packed = true];
  repeated uint64 uint64_data = 11 [packed = true];
}

而input数组是ValueInfoProto类型,只记录了张量的名字:

message ValueInfoProto {
  optional string name = 1;     // namespace Value
  optional TypeProto type = 2;
  optional string doc_string = 3;
}

2. from_onnx流程

relay.frontend.from_onnx就是读入onnx模型,按照前述onnx模型数据结构,解析模型的initializer、input、nodes和output,将算子转换为tvm relay ir算子和表达式,最终得到整个模型的tvm relay IRModule。from_onnx定义在python/tvm/relay/frontend/onnx.py:

def from_onnx(
    model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
    global ONNX_DEFAULT_CONFIGS
    if convert_config is not None:
        ONNX_DEFAULT_CONFIGS.update(convert_config)

    try:
        import onnx

        if hasattr(onnx.checker, "check_model"):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except Exception as e:  # pylint: disable=c-extension-no-member, broad-except
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype, freeze_params)
    graph = model.graph

    try:
        opset_in_model = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            ""
            f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
            f"That might cause model conversion errors."
        )

    # Use the graph proto as a scope so that ops can access other nodes if needed.
    with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

我们可以只用关注下面几行代码

...

g = GraphProto(shape, dtype, freeze_params)
graph = model.graph

...

with g:
        mod, params = g.from_onnx(graph, opset)
    return mod, params

这里生成一个TVM的GraphProto实例g, 然后将传入的onnx模型graph传入GraphProto的from_onnx方法。

GraphProto.from_onnx的参数有onnx模型的TVM GraphProto实例、版本信息、和转换结果返回方式配置:如果设置为true,则只打印onnx模型转为tvm后的表示;默认为false,将返回onnx模型的TVM中间表示数据mod(tvm.IRModule类型)和参数params。

    def from_onnx(self, graph, opset, get_output_expr=False):
        """Construct Relay expression from ONNX graph.

        Onnx graph is a python protobuf object.
        The companion parameters will be handled automatically.
        However, the input names from onnx graph is vague, mixing inputs and
        network weights/bias such as "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
        "input_1"... And renaming parameters to "param_0", "param_1"...

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        opset : opset version

        get_output_expr: bool
            If set to true, this conversion will return each output expression rather
            than a packaged module. This can be useful when converting subgraphs to
            relay.

        Returns
        -------
        mod : tvm.IRModule
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        self.opset = opset
        self._parse_graph_initializers(graph)
        self._parse_graph_input(graph)
        self._check_user_inputs_in_outermost_graph_scope()
        self._check_for_unsupported_ops(graph)
        self._construct_nodes(graph)

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        # If requested, directly return the converted expressions.
        if get_output_expr:
            return outputs
        ## Maintain the order of inputs and parameters from the ONNX graph, but only include
        ## those parameters that are needed to execute the relay graph
        free_vars = analysis.free_vars(outputs)
        nodes = {v: k for k, v in self._nodes.items()}
        free_vars = [nodes[var] for var in free_vars]
        for i_name in self._params:
            if i_name in free_vars and i_name not in self._inputs:
                self._inputs[i_name] = self._nodes[i_name]
        # Create a function from our output expression and all input variables.
        func = _function.Function([v for k, v in self._inputs.items()], outputs)
        return IRModule.from_expr(func), self._params

2.1 解析onnx模型权重

from_onnx中,首先调用_parse_graph_initializers从onnx模型的initializer数据段中解析转换网络权重数据:

    def _parse_graph_initializers(self, graph):
        """Parse network inputs to relay, aka parameters."""
        # onnx的initializer存放了模型的每个节点的权重参数
        for init_tensor in graph.initializer:
            # 这里的init_tensor.name权重张量的名字.例如网络中某个卷积权重参数W名字为186,偏置参数B名
		    # 字为188, graph.initializer里面就会有两个对应的节点,name分别为186和188

            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            # 将权重矩阵转换为tvm的nD-array
            array = self._parse_array(init_tensor)
            if self._freeze_params:
                # 如果设置了feeze_params参数,则看做常量节点(针对不定参数模型的优化?)
                self._nodes[init_tensor.name] = _expr.const(array)
            else:
                # 将解析的参数记录到参数表中
                self._params[init_tensor.name] = array
                # 在节点表中新增一个节点记录该参数
                self._nodes[init_tensor.name] = new_var(
                    init_tensor.name,
                    shape=self._params[init_tensor.name].shape,
                    dtype=self._params[init_tensor.name].dtype,
                )

如前onnx模型结构介绍所述,initializer中的张量必须是有名字的,所以这里对每个元素都判断是否有名字,如果名字字段为空,则认为网络不合法。

_parse_array是将onnx的tensor转换为了tvm的numpy数组。

2.2 解析onnx网络graph的input字段

如前所述,GraphProto中的input数组不仅包含模型的输入,还包含了模型中各节点的权重。也就是将网络的输入节点和各网络节点的权重参数都当作输入

def _parse_graph_input(self, graph):
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            # 获取参数的name, shape, type等信息
            i_name, i_shape, d_type, i_shape_name = get_info(i)
            if i_name in self._params:
                # i is a param instead of input
                # 如果是节点参数,则在前面graph.initializer的处理中已经在参数表中添加了对应的节点
                self._num_param += 1
                self._nodes[i_name] = new_var(
                    i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
                )
            elif i_name in self._nodes:
                continue
            else:
                # 如果是模型的输入
                self._num_input += 1
                self._input_names.append(i_name)
                # self._shape是用户在调用from_onnx的时候传入的模型输入shape参数
                if i_name in self._shape:
                    i_shape = self._shape[i_name]
                else:
                    # 模型的输入shape有不定项
                    if "?" in str(i_shape):
                        warning_msg = (
                            "Input %s has unknown dimension shapes: %s. "
                            "Specifying static values may improve performance"
                            % (i_name, str(i_shape_name))
                        )
                        warnings.warn(warning_msg)
                if isinstance(self._dtype, dict):
                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                else:
                    dtype = d_type
                # 在nodes表中加入输入节点
                self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
            #记录模型的输入node
            self._inputs[i_name] = self._nodes[i_name]

get_info函数是解析onnx的ValueInfoProto类型数据,获取数据的name、shape、类型等信息。

2.3 检查是否有不支持的算子

调用_check_for_unsupported_ops检查当前的onnx网络中所有算子是不是都能转换为tvm relay ir 

    def _check_for_unsupported_ops(self, graph):
        # 获取onnx算子到tvm relay ir的转换映射表
        convert_map = _get_convert_map(self.opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if (
                op_name not in convert_map
                and op_name != "Constant"
                and op_name not in _identity_list
            ):          
			    # 如果算子不在映射表中,也不在_identity_list表中,则认为当前算子是TVM不支持的
                unsupported_ops.add(op_name)
        # 如果有不支持的算子,则转换失败
        if unsupported_ops:
            msg = "The following operators are not supported for frontend ONNX: "
            msg += ", ".join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)

 _get_convert_map根据(调用from_onnx时传入的)版本号,返回一个表,每个表单元的索引是onnx算子名称,key值是该算子转换为tvm relay ir形式的接口。如果某个onnx算子没有对应的转换接口,就认为tvm当前不支持该算子。具体转换细节可以参考onnx到tvm relay ir的转换。

2.4 创建网络的DAG

然后调用_construct_nodes函数,解析onnx网络的各个节点以及节点连接关系,在tvm中创建网络的DAG(有向无环图) 

    def _construct_nodes(self, graph):
        """Nodes are stored as directed acyclic graph."""
        # 遍历onnx模型的每个算子节点
        for node in graph.node:
            #算子名称,不是节点名称
            op_name = node.op_type
            #解析节点属性
            attr = self._parse_attr(node.attribute)
            # Create and populate input list.
            inputs = onnx_input()
            # 获取节点的所有输入
            for i in node.input:
                if i != "":
                    inputs.append(self._nodes[self._renames.get(i, i)])
                else:
                    inputs.append(None)
            i_name = self._parse_value_proto(node)
            # 获取节点的输出
            node_output = self._fix_outputs(op_name, node.output)
            attr["tvm_custom"] = {}
            attr["tvm_custom"]["name"] = i_name
            attr["tvm_custom"]["num_outputs"] = len(node_output)
            # 将当前onnx节点转换为tvm relay ir
            op = self._convert_operator(op_name, inputs, attr, self.opset)
            if not isinstance(op, _expr.TupleWrapper):
                outputs_num = 1
            else:
                outputs_num = len(op)

            if outputs_num == 1:
                op = fold_constant(op)
            else:
                op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

            if outputs_num > 1:
                # ONNX supports optional outputs for some nodes.
                # This block searches for missing outputs in the ONNX graph
                # and removes any unneeded ops
                #下面这段代码的意思是:onnx支持一个节点有多个输出,但是有些输出并不实际使用.在转换为tvm relay ir的时候,我们将这些输出剔除掉.具体做法如下:
                # 获取节点的有效输出.如果某个输出没有名字,那么认为这个输出没有被(自己或者其他节点)使用,是无效输出
                valid_outputs = [False] * outputs_num
                for i, output in enumerate(node_output):
                    if output != "":
                        valid_outputs[i] = True
                # If we have outputs ONNX isn't expecting, we need to drop them
                # 如果节点有输出是无效
                if not all(valid_outputs):
                    # 这里op为onnx转换后的tvm表示
                    tup = op.astuple()
                    # TupleWrapper can also wrap ops with TupleType outputs
                    # 从tvm表达式中将有效输出对应的部分挑出来,组成当前节点的实际输出
                    if isinstance(tup, _expr.Tuple):
                        # For tuples, we extract the fields instead of using GetTupleItem
                        outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
                    else:
                        # For call nodes, we need to GetTupleItem
                        outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
                    # Create the new op with valid outputs
                    if len(outputs) == 1:
                        op = outputs[0]
                    # 如果有多个输出并且有无效输出被剔除,需要重新打包当前节点的tvm relay ir
                    elif len(outputs) != outputs_num:
                        op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
                    # Drop invalid outputs for the onnx node
                    # 更新onnx 节点的输出表, string类型
                    outputs_num = len(outputs)                    
                    node_output = [output for output in node_output if output != ""]
            assert (
                len(node_output) == outputs_num
            ), "Number of output mismatch {} vs {} in {}.".format(
                len(node_output), outputs_num, op_name
            )
            
            # 将输出加入节点表, 节点的值为节点的tvm表示
            if outputs_num == 1:
                self._nodes[node_output[0]] = op
            else:
                for k, i in zip(list(node_output), range(len(node_output))):
                    self._nodes[k] = op[i]

代码中调用_convert_operator将onnx算子转换为tvm relay ir。(前面_get_convert_map得到的是各个onnx算子的转换接口,接口执行后得到的才是tvm relay ir)。详细的转换流程见onnx到tvm relay ir的转换。

2.5 处理onnx模型输出

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        # If requested, directly return the converted expressions.
        if get_output_expr:
            return outputs

graph.output是onnx模型的输出,是一个onnx ValueInfoProto数组。_parse_vale_proto(i)返回输出i的名字。而当前self._nodes中是每个节点的输出的tvm relay ir,里面当然也就有网络最后一个节点的。所以第一行的outputs返回了网络所有输出节点的tvm relay ir。因为某个节点的输入是上一个节点的输出,而这上一个节点的输出也是tvm relay ir,被带入当前节点。例如某一个节点的tvm relay ir:

################################################
onnx op node  Convolution110
output:  ['Convolution110_Output_0']
convert to tvm op:  <class 'tvm.relay.expr.Call'>
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5])
#####################################################

数据流向上的下一个节点:

################################################
onnx op node  Plus112
output:  ['Plus112_Output_0']
convert to tvm op:  <class 'tvm.relay.expr.Call'>
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
add(%6, %Parameter88)
################################################

其实这里下游节点只是最后一行那个add算子,其他的都是上游,从网络输入结点开始的各个节点的累加。 

这样按照数据流向叠加到最后整个模型的输出,得到的就是整个模型的tvm relay ir。这里代码第一行的outputs得到的就是整个网络的tvm relay ir。第二行将其打包为一个tuple

如果我们在调用GraphProto.from_onnx的时候传入的get_output_expr参数为true,那么模型转换就到此为止了,返回的是模型的tvm relay ir。但是我们在编译运行模型的脚本中调用的relay.frontend.from_onnx接口(这个接口里面调用了GraphProto.from_onnx)没有这个参数,所以这里不会返回。

2.6 打包模型转换输出

        ## Maintain the order of inputs and parameters from the ONNX graph, but only include
        ## those parameters that are needed to execute the relay graph
        free_vars = analysis.free_vars(outputs)
        nodes = {v: k for k, v in self._nodes.items()}
        free_vars = [nodes[var] for var in free_vars]
        for i_name in self._params:
            if i_name in free_vars and i_name not in self._inputs:
                self._inputs[i_name] = self._nodes[i_name]
        # Create a function from our output expression and all input variables.
        func = _function.Function([v for k, v in self._inputs.items()], outputs)
        return IRModule.from_expr(func), self._params

 这里的流程:

1. 调用python/tvm/relay/analysis/analysis.py的free_vars接口,采用post DFS算法,从网络的输出开始遍历网络的tvm relay ir,找到free变量(是什么东西?按照官方文档,应该是权重之类的);

2. 然后从节点表中获取这些fee变量对应的节点;

3. 将这些节点加入网络的输入表_inputs中;

4. 调用_function.Function,传入网络的输入,参数和转换后的网络表示,得到_func;

5. 最后返回网络的tvm表达和权重参数

返回的网络tvm表达是什么样子呢?我们可以在https://blog.csdn.net/zx_ros/article/details/125894033的模型编译运行脚本中直接打印返回的mod看看:

......
shape_dict = {input_name: data.shape}
# 导入onnx模型
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
print(mod)

...

输出

dl@dl:~/tvm_learning$ python3 mnist_onnx.py 
/home/dl/tvm/python/tvm/driver/build_module.py:267: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  warnings.warn(
def @main(%Input3: Tensor[(1, 1, 28, 28), float32] /* ty=Tensor[(1, 1, 28, 28), float32] */) -> Tensor[(1, 10), float32] {
  %0 = nn.pad(%Input3, 0f /* ty=float32 */, pad_width=[[0i64, 0i64], [0i64, 0i64], [2i64, 2i64], [2i64, 2i64]]) /* ty=Tensor[(1, 1, 32, 32), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(8, 1, 5, 5), float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(8, 1, 1), float32] */) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 8, 28, 28), float32] */;
  %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 8, 14, 14), float32] */;
  %5 = nn.pad(%4, 0f /* ty=float32 */, pad_width=[[0i64, 0i64], [0i64, 0i64], [2i64, 2i64], [2i64, 2i64]]) /* ty=Tensor[(1, 8, 18, 18), float32] */;
  %6 = nn.conv2d(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 8, 5, 5), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %7 = add(%6, meta[relay.Constant][3] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %8 = nn.relu(%7) /* ty=Tensor[(1, 16, 14, 14), float32] */;
  %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %10 = reshape(%9, newshape=[1, 256]) /* ty=Tensor[(1, 256), float32] */;
  %11 = nn.dense(%10, meta[relay.Constant][4] /* ty=Tensor[(10, 256), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(1, 10), float32] */;
  add(%11, meta[relay.Constant][5] /* ty=Tensor[(1, 10), float32] */) /* ty=Tensor[(1, 10), float32] */
}

如果进一步探索下_function.Function都干了些什么,可以看到最后是进入到C++代码中,执行了下面的C++ lamabda函数,返回一个C++的Function句柄:

TVM_REGISTER_GLOBAL("relay.ir.Function")
    .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
                       tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
      return Function(params, body, ret_type, ty_params, attrs);
    });

同样的,IRModule.from_expr最后调用的是C++代码中TVM_REGISTER_GLOBAL("ir.Module_FromExpr")注册的接口,接口函数执行IRModule::FromExpr,FromExpr调用IRModule::FromExprInContext,生成一个C++端的IRModule实例

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值