提取深度学习模型的计算图

引言

在深度学习的世界里,理解和优化模型是提高性能的关键。计算图作为构建和训练神经网络的基础,提供了模型运算的可视化表示。本文将探讨基于PyTorch,如何来提取深度学习模型的计算图。

将深度学习的模型变成机器能识别的机器码,即AI框架编译器需要实现的工作。PyTorch 将编译器分解为三个部分,获取图即为本文讨论的部分。
● 图获取(graph acquisition)
● 图下降(graph lowering)
● 图编译(graph compilation)

pytorch 1.0之前构建了 torch.jit.trace、TorchScript、FX tracing、Lazy Tensors工具来实现上文说的编译器三部曲,其中 torch.jit.trace、FX tracing可以实现提取计算图,但是效果不尽人意。

pytorch2.0提供了以下的工具:
● TorchDynamo:可靠且快速地获取计算图
● TorchInductor:使用 define-by-run IR 的快速代码生成。TorchInductor 使用 pythonic define-by-run loop level IR 自动将 PyTorch 模型映射到 GPU 上生成的 Triton 代码和 CPU 上的 C++/OpenMP。
● AOTAutograd:在 AOT 计算图中复用 Autograd。AOTAutograd 利用 PyTorch 的 torch_dispatch 扩展机制来追踪 Autograd 引擎,能够“提前”捕获反向传播,能够使用 TorchInductor 同时加速前向和反向传播。
● PrimTorch:稳定的原始运算符

计算图

计算图本质是一种有向无环图(DAG),其中节点(Node)表示操作或数学函数,边(Edge)表示张量或数据。计算图为深度学习中的前向传播(forward propagation)和反向传播(backward propagation)提供了一个可视化的框架,它能清楚地展示数据是如何流动和操作的。这种表示方法有助于TensorFlow、PyTorch等框架在后端执行自动微分计算。

早期的图模式或者叫define-and-run的静态图框架有Caffe,TensorFlow等,它们设计了一个表示图的IR,用户通过调用这些框架提供的API来构建IR。然后我们可以在这个IR上做程序微分,将IR切分到设备上实现并行,量化,性能优化等等。

在PyTorch中,计算图是随着代码计算过程动态构建的。“计算图捕获”(Graph Capture)就是通过分析/运行代码来捕获计算图的过程。捕获了计算图之后,PyTorch能够得到整个计算过程的全局视角,从而有可能进行更好的优化。

计算图捕获

如果计算过程比较简单,计算图捕获就比较容易。然而,由于PyTorch的动态特性,计算过程中各种可能都会出现,包括但不限于:使用了条件判断、计算与变量的形状有关、调用了其它包(比如numpy、scipy)、调用了其它语言的扩展(比如Rust、C++)等等。
三个例子,来分别说明这三个难题:

# 包含条件判断
def conditional_computation(x):
    if x.sum() < 0:
        return x + 1
    else:
        return x - 1

# 包含形状相关的代码
def shape_dependent(x):
    bsz, *sizes = x.shape
    return x.reshape(bsz, -1)

# 包含外部代码
def external_code(x):
    import numpy as np
    return x + torch.from_numpy(np.random.random((5, 5, 5)))

pytorch 1.0 的torch.fx.symbolic_trace

from torch import fx
fx_model = fx.symbolic_trace(f)

fx_model.forward就是新的函数,fx_model.code是fx_model.forward对应的代码(的字符串表示),fx_model.graph就是捕获得到的计算图。
fx.symbolic_trace捕获期间并不执行任何实际运算,它默认假设函数的参数(这里就是x)全部都是torch.Tensor类型,用Proxy类型的变量来作为输入,记录在输入参数上执行的各种操作。fx.symbolic_trace 只对forward函数进行追踪,注册在forward上的hook函数不会被追踪。

并且无法处理上述的三个问题,只知道x是一个torch.Tensor,无法处理x.shape,也无法获知x.sum() < 0的值到底是多少、应该走哪一个分支。无法处理与x无关的代码。只能知道x与一个torch.Tensor相加,不知道其实调用了np.random.random,也不知道这个量每次调用都会改变。

pytorch 1.0 的torch.jit.trace

input = torch.randn(5, 5, 5)
f_traced = torch.jit.trace(f, input)

f_traced.graph存储了计算图,f_traced.code存储了计算图转化而来的代码。
即时追踪(Just-In-Time Tracing)一般是在运行第一个真实的输入数据时进行追踪,并通过追踪一些预定义算子的调用来实现计算图捕获。因此,相比于符号追踪,即时追踪在追踪过程中能够使用元信息(如形状)和值(如求和大小判断)。只要是能跑通的代码,一般都能进行jit Tracing。

但是无法解决上述的问题一、问题三。

pytorch 2.0 的torch._dynamo

torch._dynamo相关功能,是PyTorch再次骄傲地把版本号升级到2.0的理由。其本质上通过利用Python解释器提供的API,劫持全部的函数调用,分析字节码并从中获取计算图及设置守卫条件。dynamo就指的torch.compile一系列技术,不再区分dynamo和torch.compile。

def custom_backend(gm, example_inputs):
    print(gm.compile_subgraph_reason)
    print(gm.graph)
    print(gm.code)
    return gm.forward

opt_f = torch.compile(f, backend=custom_backend)
output = opt_f(input)

torch-mlir 继承自 pytorch 2.0 的 torch._dynamo 的TorchFx

目前torch-mlir 项目有两个主要的 API :torch_mlir.torchscript.compile 和 torch_mlir.fx.export_and_import。

  • 第一条路径是旧项目 pt1 代码的一部分 (torch_mlir.torchscript.compile),允许用户测试编译器的输出到不同的 MLIR 方言
  • 第二条路径 (torch_mlir.fx.export_and_import)允许用户导入任意 Python
    可调用对象(nn.Module、函数或方法)的合并 torch.export.ExportedProgram 实例,并输出到 torch dialect mlir 模块。
def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()

def save_mlir(module, name, ir):
    module_strs = str(module)
    mlir_name = name + ".mlir"
    cwd = os.getcwd()
    mlir_path = os.path.join(cwd,"iree_test", ir)
    if not os.path.exists(mlir_path):
        os.makedirs(mlir_path)
    with open(os.path.join(mlir_path, mlir_name), 'w') as f:
        f.write(module_strs)

class Transform:
    def __init__(self, f, *args, **kwargs):
        self.f = f
        self.kernel_name = f.__class__.__name__
        self.args = args
        self.kwargs = kwargs
        self.module = None
    
    def get_torchIR(self):
        self.module = fx.export_and_import(self.f, *self.args, **self.kwargs)
        print("aten ir:")
        print(self.module)
        save_mlir(self.module, self.kernel_name, "torch-aten")

    def lower_linalg(self):
        run_pipeline_with_repro_report(
            self.module,
            (
                "builtin.module("
                "func.func(torch-decompose-complex-ops),"
                "torch-backend-to-linalg-on-tensors-backend-pipeline)"
            ),
            "Lowering TorchFX IR -> Linalg IR",
            enable_ir_printing=False,
        )
        print("linalg ir:")
        print(self.module)
        save_mlir(self.module, self.kernel_name, "linalg-ir")
    
    def run(self):
        self.get_torchIR()
        self.lower_linalg()
        # return self.module

@run
def test_sigmoid():
    class Sigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x:Tensor) -> Tensor:
            return torch.sigmoid(x)

    sigmoid = Transform(Sigmoid(), torch.randn(128, 128))
    sigmoid.run()

参考

https://mp.weixin.qq.com/s/JENCa_GNGPHhOspGb79ugA
https://pytorch.org/get-started/pytorch-2.0
https://cloud.tencent.com/developer/article/2203301
https://zhuanlan.zhihu.com/p/644590863

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值