tvm.frontend.from_pytorch详细介绍(1)



一、pytorch前端整体转化流程(部分)

1.脚本化的pytorch模型

  脚本化 PyTorch 模型:首先,使用 torch.jit.script 函数对PyTorch 模型进行脚本化。这将把模型转换为 TorchScript 表示,可以独立于 Python 运行时加载和执行。

# 脚本化 PyTorch 模型
scripted_model = torch.jit.script(你的模型)
# 将脚本化的 PyTorch 模型转换为 Relay
relay_model, params = relay.frontend.from_pytorch(scripted_model, input_shapes=[(input_shape,)])

2.内联优化(_run_jit_passes)

在 JIT 过程中,为了处理 prim::CallMethod 的调用,需要执行内联传递操作,将方法调用的函数体内联展开,以便进行进一步的优化或处理。

2.1、内联优化

   这段代码的目的是在PyTorch的JIT编译过程中应用一些特定的转换和优化,以提高生成的图的性能和效率。

def _run_jit_passes(graph, enable_lower_all_tuples=True):
    """The inline pass is necessary to unwrap prim::CallMethod"""
    # pylint: disable=c-extension-no-member
    import torch

    if is_version_greater_than("1.5.1"):
        # This is required for torchvision detection models from 1.6 above
        # It is the same as _jit_pass_inline, except that it has some special
        # case behaviors for some ops such as aten::__interpolate()
        torch._C._jit_pass_onnx_function_substitution(graph)
    else:
        torch._C._jit_pass_inline(graph)

    if enable_lower_all_tuples:

   JIT passes用于对PyTorch图进行转换和优化,以便在部署或执行期间提高性能。如果PyTorch的版本大于1.5.1,它会运行一系列的passes,如果PyTorch的版本低于等于1.5.1,它将只运行一个叫做_jit_pass_inline的pass,该pass用于内联函数调用。

2.2 什么是内联函数

  Inplace操作是指在原地(in-place)修改数据或对象,而不创建新的副本。它通常用于优化内存使用和减少计算开销。
  在PyTorch中,有一些操作允许以inplace的方式进行,即直接修改操作作用的张量,而不创建新的张量对象。这些操作通常以_结尾,例如add_、mul_、div_等。通过使用inplace操作,可以减少内存分配和数据拷贝的开销,提高代码的效率。
  需要注意的是,使用inplace操作时需要小心,因为它会直接修改原始数据,可能导致意外的副作用或不可逆的修改。因此,在使用inplace操作时,应确保了解其行为,并在适当的情况下使用它们。

3.graph中的所有op(get_all_op_names)

  这段代码的目的是获取输入图中所有操作符的名称,并以集合的形式返回。它通过遍历图中的节点和子块,收集并去重所有操作符的名称。调试断点的设置可能是为了在执行过程中检查和调试代码。

def get_all_op_names(graph):
    """ Return all operator names in the input graph """
    nodes = list(graph.nodes())
    prim_with_blocks = ["prim::If", "prim::Loop"]
    for prim in prim_with_blocks:
        prim_nodes = graph.findAllNodes(prim, recurse=True)
        for prim_node in prim_nodes:
            for block in prim_node.blocks():
                nodes += block.nodes()
    return set(node.kind() for node in nodes)

函数的主要逻辑如下:

  1. 获取图中的所有节点列表。
  2. 定义了一个名为prim_with_blocks的列表,其中包含了"prim::If"和"prim::Loop"这两个具有子块(blocks)的操作符名称。
  3. 对于prim_with_blocks列表中的每个操作符名称,遍历图中所有的该操作符节点(包括子块中的节点)。
  4. 将每个节点的子块中的节点也加入到nodes列表中。
  5. 最后,返回一个集合(set),其中包含了nodes列表中每个节点的操作符名称。

3.1 各个变量的值

1 .graph

def get_all_op_names(graph):
(Pdb) p graph
graph(%self.1 : __torch__.TempOpModel,
      %input : Float(1:48, 3:16, 16:1)):
  %2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)
  %4 : Tensor = prim::GetAttr[name="weight"](%2)
  %5 : None = prim::Constant(), scope: __module.convtrans
  %6 : int = prim::Constant[value=2](), scope: __module.convtrans 
  %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans
  %8 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans
  %10 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans
  %12 : bool = prim::Constant[value=1](), scope: __module.convtrans 
  %13 : int = prim::Constant[value=0](), scope: __module.convtrans 
  %14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans
  %15 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %16 : bool = prim::Constant[value=0](), scope: __module.convtrans
  %17 : bool = prim::Constant[value=0](), scope: __module.convtrans
  %18 : bool = prim::Constant[value=1](), scope: __module.convtrans 
  %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans 
  return (%19)

2 .nodes

nodes = list(graph.nodes())
(Pdb) p graph.nodes
<bound method PyCapsule.nodes of graph(%self.1 : __torch__.TempOpModel,
      %input : Float(1:48, 3:16, 16:1)):
  %2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)
  %4 : Tensor = prim::GetAttr[name="weight"](%2)
  %5 : None = prim::Constant(), scope: __module.convtrans
  %6 : int = prim::Constant[value=2](), scope: __module.convtrans 
  %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans
  %8 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans
  %10 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans
  %12 : bool = prim::Constant[value=1](), scope: __module.convtrans 
  %13 : int = prim::Constant[value=0](), scope: __module.convtrans
  %14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans
  %15 : int = prim::Constant[value=1](), scope: __module.convtrans 
  %16 : bool = prim::Constant[value=0](), scope: __module.convtrans 
  %17 : bool = prim::Constant[value=0](), scope: __module.convtrans
  %18 : bool = prim::Constant[value=1](), scope: __module.convtrans 
  %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans 
  return (%19)

3 .p nodes

nodes = list(graph.nodes())
(Pdb) p nodes
[%2 : __torch__.torch.nn.modules.conv.ConvTranspose1d = prim::GetAttr[name="convtrans"](%self.1)
, %4 : Tensor = prim::GetAttr[name="weight"](%2)
, %5 : None = prim::Constant(), scope: __module.convtrans
, %6 : int = prim::Constant[value=2](), scope: __module.convtrans 
, %7 : int[] = prim::ListConstruct(%6), scope: __module.convtrans
, %8 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %9 : int[] = prim::ListConstruct(%8), scope: __module.convtrans
, %10 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %11 : int[] = prim::ListConstruct(%10), scope: __module.convtrans
, %12 : bool = prim::Constant[value=1](), scope: __module.convtrans 
, %13 : int = prim::Constant[value=0](), scope: __module.convtrans
, %14 : int[] = prim::ListConstruct(%13), scope: __module.convtrans
, %15 : int = prim::Constant[value=1](), scope: __module.convtrans 
, %16 : bool = prim::Constant[value=0](), scope: __module.convtrans 
, %17 : bool = prim::Constant[value=0](), scope: __module.convtrans 
, %18 : bool = prim::Constant[value=1](), scope: __module.convtrans 
, %19 : Float(1:198, 6:33, 33:1) = aten::_convolution(%input, %4, %5, %7, %9, %11, %12, %14, %15, %16, %17, %18), scope: __module.convtrans 
]

4、返回结果

  set(node.kind() for node in nodes)是一个表达式,用于创建一个集合(set),其中包含了nodes列表中每个节点的类型(kind)。

  在PyTorch中,node.kind()是用于获取节点类型的方法。每个节点在计算图中都有一个类型,表示该节点所执行的操作或功能。

  node.kind()返回一个表示节点类型的字符串,通常是以prim::开头,后面跟着具体的操作名称或标识符。例如,prim::Add、prim::Mul等。

(Pdb) p op_names
{'aten::_convolution', 'prim::Constant', 'prim::ListConstruct', 'prim::GetAttr'}

二、from_pytorch完整代码


```python
def from_pytorch(
    script_module,
    input_infos,
    custom_convert_map=None,
    default_dtype="float32",
    use_parser_friendly_name=False,
    keep_quantized_weight=False,
    export_renamed_c_graph_path=None,
    preserve_pytorch_scopes=False,
):
    """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
    The companion parameters will be handled automatically.

    Parameters
    ----------
    script_module : TopLevelTracedModule object
        TorchScripted PyTorch graph
        Note: We currently only support traces (ie: torch.jit.trace(model, input))

    input_infos : List of tuples
        Can be (input name, input shape) or (input name, (input shape, input types))
        Graph level input shape and type list
        The same input names need to be used for deployment, so choose easy to
        remember names (such as: input0, input1)
        e.g.
        [('input0', (1, 2)), ('input1', (3, 4))]
        or
        [('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]

    custom_convert_map : Dictionary of str to Relay op
        A custom op conversion map in the same format as _convert_map above

    default_type : str
        The default dtype to use when type information is not provided by PyTorch.

    use_parser_friendly_name : bool
        When True, replace '.' with `_' in a original parameter name.
        The Relay text parser treats a variable name followed by a period as a tuple element access,
        so a variable name like "dense.weight" cannot be parsed correctly.
        Use this option when you want to run the AnnotateSpans pass on the imported module.

    keep_quantized_weight : bool
        Return quantized weights and bias, rather than float ones. PyTorch stores quantized weights
        in a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We use
        a PyTorch function to unpack quantized weights into float32 arrays and quantization
        parameters. By default, we return float32 weights and rely on the QNN lowering and the
        Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however,
        we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True,
        we quantize weights in the frontend using a function that is equivalent to
        qnn.op.quantize(...) operating on Numpy arrays.

    export_renamed_c_graph_path : str, optional
        Export the renamed torch._C.Graph to the path.
        During the conversion, variable names in torch._C.Graph will be assigned based on their op
        types. The exported text file can be the reference to spans.

    preserve_pytorch_scopes : bool
        When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
        If false, a default namer is used that does not preserve the Pytorch scope names.

    Returns
    -------
    mod : tvm.IRModule
        The module that optimizations will be performed on.

    params : dict of str to tvm.runtime.NDArray
        Dict of converted parameters stored in tvm.runtime.ndarray format
    """
    import torch

    mod = tvm.IRModule()
    prelude = Prelude(mod)
    enable_lower_all_tuples = True

    converter = PyTorchOpConverter(
        prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
    )

    graph = script_module.graph.copy()

    # Check if lower_all_tuples pass can be enabled
    graph_inputs = list(graph.inputs())
    for inp in graph_inputs:
        if inp.type().kind() == "TupleType" or inp.type().kind() == "ListType":
            enable_lower_all_tuples = False
            break

    _run_jit_passes(graph, enable_lower_all_tuples)
    _redirect_inplace_output(graph)

    if custom_convert_map:
        converter.update_convert_map(custom_convert_map)

    op_names = get_all_op_names(graph)
    converter.report_missing_conversion(op_names)

    is_module = isinstance(script_module, torch.jit.ScriptModule)
    params = script_module.state_dict() if is_module else {}
    outputs = _get_relay_input_vars(
        graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
    )

    if use_parser_friendly_name:
        new_names = [key.replace(".", "_") for key in params.keys()]
        params = dict(zip(new_names, params.values()))

    # rename _C.Graph here for constructing meaningful source name of graph nodes
    # by doing so, we could Use source_map as the reference to rename model parameters
    source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
    param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
        graph, params, source_map, use_parser_friendly_name
    )

    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

    outputs.update(param_vars)

    # For quantized models
    quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
    if len(quantized_ops.intersection(set(op_names))) > 0:
        weight_quant_params = qnn_torch.get_weight_quant_params(
            script_module, packed_param_map.values()
        )
        qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)
        input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
        qnn_torch.add_quant_params_to_outputs(
            outputs,
            packed_param_map,
            weight_quant_params,
            input_scales_for_bias,
            keep_quantized_weight,
        )
        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
        converter.update_convert_map(qnn_torch.convert_map)

    operator_nodes = _get_operator_nodes(
        graph.nodes(),
        converter.source_map,
        converter.op_type_dict,
        use_parser_friendly_name,
        preserve_pytorch_scopes,
    )
    ret_name = _get_input_names(graph.return_node())
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)

    # ListConstruct kept original python list. Convert to tuple.
    outputs = [_expr.Tuple(output) if isinstance(output, list) else output for output in outputs]

    if len(outputs) > 1:
        ret = _expr.Tuple(outputs)
    else:
        ret = outputs[0]

    # Separate data inputs and parameters to make sure data inputs come first.
    func_args = []
    data_inputs = []
    for arg in _analysis.free_vars(ret):
        if arg.name_hint not in tvm_params.keys():
            data_inputs.append(arg)
        else:
            func_args.append(arg)

    # Ensures the order of data_input is the same as the order of inputs specified in input_info.
    order_input_infos = {
        input_info[0]: len(input_infos) - idx for idx, input_info in enumerate(input_infos)
    }
    data_inputs = sorted(
        data_inputs,
        key=lambda data_input: order_input_infos[data_input.name_hint]
        if data_input.name_hint in order_input_infos
        else -1,
        reverse=True,
    )

    func_args = data_inputs + func_args

    mod["main"] = tvm.relay.Function(func_args, ret)

    if export_renamed_c_graph_path:
        export_c_graph(export_renamed_c_graph_path, graph)

    return transform.RemoveUnusedFunctions()(mod), tvm_params


TVM的te(Tensor Expression)中,`tvm.te.if_then_else` 函数用于实现条件语句,根据一个布尔条件选择执行不同的计算逻辑。 `if_then_else` 函数的基本语法如下: ```python tvm.te.if_then_else(condition, then_expr, else_expr) ``` 参数说明: - `condition`:布尔条件表达式,用于决定选择哪个分支的计算逻辑。 - `then_expr`:当条件为真时执行的计算表达式。 - `else_expr`:当条件为假时执行的计算表达式。 下面是一个示例代码,演示如何使用 `if_then_else` 函数: ```python import tvm from tvm import te def if_then_else_example(): # 输入张量形状 shape = (4, ) # 创建输入和输出张量 input_tensor = te.placeholder(shape, name='input_tensor', dtype='float32') output_tensor = te.placeholder(shape, name='output_tensor', dtype='float32') # 定义计算 def compute(i): # 根据输入张量的值判断条件 condition = input_tensor[i] > 0 # 根据条件选择执行计算逻辑 then_expr = input_tensor[i] * 2 else_expr = input_tensor[i] / 2 # 使用 if_then_else 函数实现条件选择 return tvm.te.if_then_else(condition, then_expr, else_expr) # 创建计算描述 output = te.compute(shape, compute, name='output') return output.op.body[0] # 创建一个范围上下文 with tvm.target.Target('llvm'): # 构造计算图 stmt = if_then_else_example() # 打印生成的计算图 print(stmt) ``` 在上述示例中,我们定义了一个 `if_then_else_example()` 函数,创建了输入张量和输出张量。然后在 `compute()` 中,根据输入张量的值判断条件,并使用 `if_then_else` 函数实现条件选择:当输入张量大于0时,执行乘以2的计算逻辑;否则,执行除以2的计算逻辑。最后通过 `te.compute()` 创建计算描述,并返回计算图的第一个操作节点。 希望这能解答您的疑问!如果您还有其他问题,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值