引言
在深度学习的世界里,理解和优化模型是提高性能的关键。计算图作为构建和训练神经网络的基础,提供了模型运算的可视化表示。本文将探讨基于PyTorch,如何来提取深度学习模型的计算图。
将深度学习的模型变成机器能识别的机器码,即AI框架编译器需要实现的工作。PyTorch 将编译器分解为三个部分,获取图即为本文讨论的部分。
● 图获取(graph acquisition)
● 图下降(graph lowering)
● 图编译(graph compilation)
pytorch 1.0之前构建了 torch.jit.trace、TorchScript、FX tracing、Lazy Tensors工具来实现上文说的编译器三部曲,其中 torch.jit.trace、FX tracing可以实现提取计算图,但是效果不尽人意。
pytorch2.0提供了以下的工具:
● TorchDynamo:可靠且快速地获取计算图
● TorchInductor:使用 define-by-run IR 的快速代码生成。TorchInductor 使用 pythonic define-by-run loop level IR 自动将 PyTorch 模型映射到 GPU 上生成的 Triton 代码和 CPU 上的 C++/OpenMP。
● AOTAutograd:在 AOT 计算图中复用 Autograd。AOTAutograd 利用 PyTorch 的 torch_dispatch 扩展机制来追踪 Autograd 引擎,能够“提前”捕获反向传播,能够使用 TorchInductor 同时加速前向和反向传播。
● PrimTorch:稳定的原始运算符
计算图
计算图本质是一种有向无环图(DAG),其中节点(Node)表示操作或数学函数,边(Edge)表示张量或数据。计算图为深度学习中的前向传播(forward propagation)和反向传播(backward propagation)提供了一个可视化的框架,它能清楚地展示数据是如何流动和操作的。这种表示方法有助于TensorFlow、PyTorch等框架在后端执行自动微分计算。
早期的图模式或者叫define-and-run的静态图框架有Caffe,TensorFlow等,它们设计了一个表示图的IR,用户通过调用这些框架提供的API来构建IR。然后我们可以在这个IR上做程序微分,将IR切分到设备上实现并行,量化,性能优化等等。
在PyTorch中,计算图是随着代码计算过程动态构建的。“计算图捕获”(Graph Capture)就是通过分析/运行代码来捕获计算图的过程。捕获了计算图之后,PyTorch能够得到整个计算过程的全局视角,从而有可能进行更好的优化。
计算图捕获
如果计算过程比较简单,计算图捕获就比较容易。然而,由于PyTorch的动态特性,计算过程中各种可能都会出现,包括但不限于:使用了条件判断、计算与变量的形状有关、调用了其它包(比如numpy、scipy)、调用了其它语言的扩展(比如Rust、C++)等等。
三个例子,来分别说明这三个难题:
# 包含条件判断
def conditional_computation(x):
if x.sum() < 0:
return x + 1
else:
return x - 1
# 包含形状相关的代码
def shape_dependent(x):
bsz, *sizes = x.shape
return x.reshape(bsz, -1)
# 包含外部代码
def external_code(x):
import numpy as np
return x + torch.from_numpy(np.random.random((5, 5, 5)))
pytorch 1.0 的torch.fx.symbolic_trace
from torch import fx
fx_model = fx.symbolic_trace(f)
fx_model.forward就是新的函数,fx_model.code是fx_model.forward对应的代码(的字符串表示),fx_model.graph就是捕获得到的计算图。
fx.symbolic_trace捕获期间并不执行任何实际运算,它默认假设函数的参数(这里就是x)全部都是torch.Tensor类型,用Proxy类型的变量来作为输入,记录在输入参数上执行的各种操作。fx.symbolic_trace 只对forward函数进行追踪,注册在forward上的hook函数不会被追踪。
并且无法处理上述的三个问题,只知道x是一个torch.Tensor,无法处理x.shape,也无法获知x.sum() < 0的值到底是多少、应该走哪一个分支。无法处理与x无关的代码。只能知道x与一个torch.Tensor相加,不知道其实调用了np.random.random,也不知道这个量每次调用都会改变。
pytorch 1.0 的torch.jit.trace
input = torch.randn(5, 5, 5)
f_traced = torch.jit.trace(f, input)
f_traced.graph存储了计算图,f_traced.code存储了计算图转化而来的代码。
即时追踪(Just-In-Time Tracing)一般是在运行第一个真实的输入数据时进行追踪,并通过追踪一些预定义算子的调用来实现计算图捕获。因此,相比于符号追踪,即时追踪在追踪过程中能够使用元信息(如形状)和值(如求和大小判断)。只要是能跑通的代码,一般都能进行jit Tracing。
但是无法解决上述的问题一、问题三。
pytorch 2.0 的torch._dynamo
torch._dynamo相关功能,是PyTorch再次骄傲地把版本号升级到2.0的理由。其本质上通过利用Python解释器提供的API,劫持全部的函数调用,分析字节码并从中获取计算图及设置守卫条件。dynamo就指的torch.compile一系列技术,不再区分dynamo和torch.compile。
def custom_backend(gm, example_inputs):
print(gm.compile_subgraph_reason)
print(gm.graph)
print(gm.code)
return gm.forward
opt_f = torch.compile(f, backend=custom_backend)
output = opt_f(input)
torch-mlir 继承自 pytorch 2.0 的 torch._dynamo 的TorchFx
目前torch-mlir 项目有两个主要的 API :torch_mlir.torchscript.compile 和 torch_mlir.fx.export_and_import。
- 第一条路径是旧项目 pt1 代码的一部分 (torch_mlir.torchscript.compile),允许用户测试编译器的输出到不同的 MLIR 方言
- 第二条路径 (torch_mlir.fx.export_and_import)允许用户导入任意 Python
可调用对象(nn.Module、函数或方法)的合并 torch.export.ExportedProgram 实例,并输出到 torch dialect mlir 模块。
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
def save_mlir(module, name, ir):
module_strs = str(module)
mlir_name = name + ".mlir"
cwd = os.getcwd()
mlir_path = os.path.join(cwd,"iree_test", ir)
if not os.path.exists(mlir_path):
os.makedirs(mlir_path)
with open(os.path.join(mlir_path, mlir_name), 'w') as f:
f.write(module_strs)
class Transform:
def __init__(self, f, *args, **kwargs):
self.f = f
self.kernel_name = f.__class__.__name__
self.args = args
self.kwargs = kwargs
self.module = None
def get_torchIR(self):
self.module = fx.export_and_import(self.f, *self.args, **self.kwargs)
print("aten ir:")
print(self.module)
save_mlir(self.module, self.kernel_name, "torch-aten")
def lower_linalg(self):
run_pipeline_with_repro_report(
self.module,
(
"builtin.module("
"func.func(torch-decompose-complex-ops),"
"torch-backend-to-linalg-on-tensors-backend-pipeline)"
),
"Lowering TorchFX IR -> Linalg IR",
enable_ir_printing=False,
)
print("linalg ir:")
print(self.module)
save_mlir(self.module, self.kernel_name, "linalg-ir")
def run(self):
self.get_torchIR()
self.lower_linalg()
# return self.module
@run
def test_sigmoid():
class Sigmoid(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x:Tensor) -> Tensor:
return torch.sigmoid(x)
sigmoid = Transform(Sigmoid(), torch.randn(128, 128))
sigmoid.run()
参考
https://mp.weixin.qq.com/s/JENCa_GNGPHhOspGb79ugA
https://pytorch.org/get-started/pytorch-2.0
https://cloud.tencent.com/developer/article/2203301
https://zhuanlan.zhihu.com/p/644590863