【编译系列】Torch.compile()流程解析——3. AOTAutograd

本文是torch.compile()流程解析系列文章,因此很多代码和样例需要结合前面的文章才会更容易理解哦~

AOTAutograd(Ahead-Of-Time Autograd)介绍

上一节TorchDynamo的介绍中,我们解析了torch.compile()是如何捕获计算图并保存为GraphModule的,但在这个过程中只是对整个Python字节码进行了模拟执行、解析并构建FX Graph,相当于只是初步构建了前向计算图,没有捕获训练场景下的反向计算图。在PyTorch中反向计算图的捕获是放在backend compiler里面实现了,以torch.compile的默认backend compiler——inductor为例,在其实现函数compile_fx中,涵盖了AOTAutograd(捕获fw-bw joint graph)、PrimTorch(lowering op)和TorchInductor(图优化、Triton)。下面开始解析backend compiler的默认函数——inductor的函数实现,并梳理出剩下三个组件的原理。

首先介绍AOTAutograd,AOTAutograd是 PyTorch 引入的一种自动求导机制,旨在在模型执行之前预先生成梯度计算的代码。这种方法通过静态分析模型的前向计算图,提前生成反向传播所需的梯度计算逻辑,从而减少运行时的开销,提升训练效率。有了 AOTAutograd,开发者可以做以下事情:

  • 获取反向传播计算图、甚至是正向传播和反向传播联合的计算图;
  • 用不同的后端编译器分别编译正向传播和反向传播计算图;
  • 针对训练 (training) 做正向传播、反向传播联合优化,比如通过在反向传播中重算 (recompute) 来减少正向传播为反向传播保留的 tensor,从而削减内存需求;

总的来说,AOTAutograd的工作流程如下:

  • 基于__torch_dispatch__机制trace正向反向传播,生成联合计算图(joint graph)。
  • 通过decompositions进一步拆解,将FX Graph进一步转换为更低层次的中间表示,即PrimTorch。
  • 通过partition_fn将joint-graph切分成正反向计算图。
  • 调用fw_compiler和bw_compiler对正向、反向计算图分别进行编译,并整合成一个torch.autograd.Function。

torch dispatch

AOTAutograd是基于__torch_dispatch__机制在算子下发执行前获得真正实际执行的op,并构建对应的Proxy,即PyTorch反向传播的计算图是在执行正向过程中动态创建的,这也意味着执行完整的前向过程才能构建出对应的FX Graph,从而在函数正式执行前拿到正反向计算图,实现AOTAutograd,而这一过程也是依赖于前面TorchDynamo捕获的FX Graph这一IR表示。

在正式解析AOTAutograd之前先了解一下__torch_dispatch__机制。PyTorch的核心是一个dispatcher,功能是根据输入tensor的属性把算子dispatch到具体的kernel上,如根据tensor的device属性决定是调用CUDA kernel还是CPU实现,从而综合各项属性算出一个dispatch key决定调用哪个kernel。一个算子在PyTorch中往往要经过多次dispatch,而**__torch_patch__给开发者提供了在算子最终dispatch前获取对应的算子和输入的接口。
pytorch算子dispatch过程

后续AOTAutograd实现的代码逻辑如下,感兴趣的小伙伴也可以看看后面的代码解析部分

AOTAutograd函数调用


Joint Graph

在TorchDynamo那一篇的第3节中提及通过在TorchDynamo构建FX Graph后会调用call_user_compiler()调用backend compiler对计算图进行编译,torch.compile()的默认编译函数实现inductor的入口函数是compile_fx()。

分析compile_fx()的函数调用栈,其核心实现是aot_dispatch_autograd()函数,其主要流程如下:

  1. 调用aot_dispatch_autograd_graph()生成前反向joint graph。
  2. 调用partition_fn进行切分,最后返回包含前、反向计算图的torch.autograd.Function。

首先介绍aot_dispatch_autograd_graph()函数生成joint graph的过程:

  1. 通过create_joint()函数将正反向计算封装成函数,create_joint()根据前向计算结果分析出需要计算梯度的参数以及对应的tangents(梯度权值),然后通过torch.autograd.grad进行反向求导,并将正反向过程封装在函数中返回,作为joint_fn_to_trace。
  2. 由_create_graph()对joint_fn_to_trace函数进行跟踪,核心是调用make_fx()函数在算子dispatch前拿到实际真正执行的op并创建Proxy添加到FX Graph中。

在make_fx()函数中是通过_MakefxTracer.trace()函数对整个函数计算过程进行跟踪并生成GraphModule,GraphModule中包含正反向计算对应的计算图。需要注意的是这里的正反向计算是TorchDynamo graph break对应的子图,即每个子图都会调用一次make_fx生成joint graph。捕获过程主要包括两个核心操作:

  1. 对输入输出的封装:在dispatch_trace()->Tracer.trace()中会为函数参数、局部变量以及输出生成对应的Proxy。其中通过create_args_for_root()->create_proxy()为所有变量(函数参数和局部变量)创建类型为placeholder的Proxy,在create_proxy()中会同步创建Node并将其加入到FX Graph中,并用Proxy封装一下Node。通过create_node()为输出创建类型为 output 的Node,并将其加入到FX Graph中)。
  2. op dispatch的捕获和封装: 在with decompose()上下文管理中通过self.proxy_mode指定了ProxyTorchDispatchMode(用于拦截和自定义张量操作的分发过程,Dispatch Mode 机制允许开发者在张量操作(如加法、矩阵乘法等)被执行时,插入自定义逻辑,以实现诸如调试、性能监控、自定义后端支持等功能,而不需要修改Python的核心代码)。通过重写__torch_dispatch__函数指定op dispatch过程中插入的操作,在ProxyTorchDispatchMode中是对op的decompose(拆解到PrimTorch规定的集合中),同时为op创建类型为call_function()的Proxy。
# aot_dispatch_autograd_graph()函数实现
# ps:这里只展示核心函数调用
def aot_dispatch_autograd_graph(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
    # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
    # It includes outputs of the original forward, *and* any updated inputs due to input mutations.
    # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
    joint_inputs = (flat_args, fw_metadata.traced_tangents)

    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)    # 生成正反向计算,封装成函数

    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)    # 通过make_fx()函数跟踪joint_fn_to_trace的计算过程生成joint_graph,以torch.fx.GraphModule格式返回
# _create_graph()的核心实现实现
# path:torch/fx/experimental/proxy_tensor.py::class _MakefxTracer
# ps:只展示核心代码实现
def _trace_inner(self, f, *args):
    phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]

    args = _wrap_fake(args)
    func = _wrap_func(f, phs)
    # We disable the autocast cache as the autocast cache causes type conversions on parameters to
    # check a cache, which introduces untracked tensors into the graph
    #
    # We also disable tracing by any other tensor proxy-based tracers except the current. The
    # purpose of `make_fx	
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值