TorchDynamo初探:Python ByteCode的动态修改

TorchDynamo是PyTorch实验性的JIT编译接口,它允许在运行时动态修改Python ByteCode,以实现更高效的计算图优化。不同于基于trace tensor或AST解析的计算图生成,TorchDynamo在CPython的ByteCode阶段介入,提供FX Graph接口供用户自定义计算逻辑。这种方式避免了传统静态编译的限制,支持所有Python语法,且开销较低。文章介绍了CPython的执行流程、TorchDynamo的工作原理以及其实现方式,展示了如何在ByteCode层面实现动态优化。
摘要由CSDN通过智能技术生成

87ea1c847e40dff79a5050335be49924.jpeg


作者|strint

1
背景

深度学习框架编译优化时,需要先根据计算逻辑形成一个逻辑计算图,然后再改写计算图,最后执行改写后的计算图。其中生成逻辑计算图方式有两种。

一种计算图生成是基于 trace tensor 的,跟踪 tensor 的执行路径。tensor 执行时,基于函数重载,可以落到支持 tensor 计算的框架自定义函数,该函数一般是 c++ 层的。c++ 层的自定义函数中,功能是用于生成一个 Operation 的符号表达。比如一个对于加法运算,trace 就是记录一个符号化的加法算子。如此一连串的运算就被转换了符号化的计算图。

另外一种计算图生成是基于 AST(抽象语法树) 解析的。在代码执行前,直接根据 Python 文本代码得到 Python AST,然后根据 AST 来翻译成计算图(也叫做中间代码 IR)。

Python(特指 CPython)解释器执行,第一阶段会先把 Python 源码解析成 AST,第二阶段根据 AST 生成和优化 ByteCode(字节码),第三阶段在虚拟机中执行 ByteCode。

基于 AST 解析的计算图生成,发生在这里的第一阶段;基于 trace tensor 的计算图生成,发生在第三阶段之后。

TorchDynamo 特别的地方在于其工作在第二阶段,动态修改 Python ByteCode,这样第三阶段执行的已经是修改后的 ByteCode了。

2

TorchDynamo 概述

TorchDynamo 是 PyTorch 新实验的 JIT 编译接口,支持使用 Python 在运行时修改动态执行逻辑,修改的时机是 CPython 的 ByteCode 执行前。这个思想类似 DynamoRIO(https://dynamorio.org) 项目,DynamoRIO 可以动态的修改 x86 机器码。

CPython 的每次函数调用会生成一个 Frame(或者叫 Stack),Frame 中带有的代码部分就是 ByteCode。CPython 运行时支持基于现有的 Frame 去设置一个自定义的 Frame,然后后面执行的就是自定义的 Frame。

TorchDynamo 的工作原理就是在运行时设置一个自定义的 Frame,该 Frame 中的 ByteCode 支持 CallBack 到 Python 层去修改。其提供的典型的修改接口是 FX Graph,也就是说 TorchDynamo 会分析 ByteCode,生成对应的 FX Graph,然后提供 FX Graph 的接口供用户自定义计算图。这种做法有如下优点:

  • 可以支持所有的 Python 语法,因为如果在自定义 Frame 过程中的任何一点发现不支持,都可以选择不修改 Frame 而回退到原 Frame;

  • 开销少,劫持发生在 Python 执行比较早的阶段(ByteCode 生成和优化阶段),而非 Python ByteCode 执行后的阶段,有时可以减少 Python ByteCode 的执行开销(猜测如果很多次 ByteCode 层面的函数调用被融合层成一次函数调用,的确可以缩减开销);

  • 可以做到不增加编译带来的延迟(之前的基于 tensor trace 或者 ast 解析的做法,一般都有先编译执行所以编译开销无法掩盖,但是改写 ByteCode 这个做法,猜测是可以在识别出热点代码后,单独开一个线程去做编译,而不影响主线程工作。Python ByteCode 改写的 API 中有这种延迟编译的样例,peps.python.org/pep-052 )。

之前计算图生成机制(基于 trace tensor、基于 AST 解析的)中的几个问题,得到了缓解:

  • 存在无法静态化的操作,之前一般需要显式的移除静态化作用域,现在总是允许不做编译,直接执行原 Python 代码,这样使得静态化标注变得简单;

  • 打开静态图编译优化,之前编译时一般无法掩盖,现在有办法部分掩盖;

  • 动态 shape 问题,因为有了编译时和运行时的掩盖,也可以得到缓解。

这种尽量优化、动态优化的设计,最大程度了照顾了代码开发的体验,让编译优化上手变得更简单了。这是 TorchDynamo 带来的最主要的好处。这种做法非常符合 PyTorch 的 Python First、Eager First、User Experience First的偏好。但是这个设计对于寻求最好的性能、最方便的静态化部署这两个目标并没有改善。

3

CPython 的标准执行流程

上文提到了 CPython 的执行从 Python 文本代码,到 AST,到 ByteCode。这里用一个示例展开看一下。Python 的标准组件非常易用,可以在 Python 层用 ast 组件来查看 AST,可以用 compile 内置函数来编译 ByteCode,可以用 exec 系统函数来执行 ByteCode。我们先在代码开头导入相关组件:

import ast
import dis
import sys

然后我们构造一个 python 代码,可以看到 src_code 就是普通的字符串。其中包含了一段普通的 python 内置的乘法,一段深度学习的 tensor scalar 加法,最后一段是当前Python Frame 中的 ByteCode 关联对象的打印(用于一个检验,后面会提到)。

print("=== source code ===")
src_code = """
# normal python operation
x = 1
x = x * 2


# tensor operation
y = dl_framework.ones((1, 2))
z = x + y
print(z)


# print python frame
f = sys._getframe()
# print the code object
print(f.f_code)
"""
print(src_code)

然后使用 ast 组件来生成这段代码的 AST。

print("=== source code to ast ===")
# 把源代码解析成 AST
ast_obj = ast.parse(src_code)
# 打印 AST
print(ast.dump(ast_obj))

可以得到 AST,这里展示的结果额外做了格式化,另外删减掉了和计算逻辑无关的打印 frame 的部分,代码和其 AST 的对应关系参见注释。AST解析是纯文本层面的,`dl_framework` 还没有被 import 进来,AST解析仍然可以正常工作。AST 基本是一个多叉树的结构,每个节点对应一个表达式,节点子节点代表子表达式。以 `x = x + 2` 为例,Assign 是一个节点,是赋值运算,被赋值的是 `x`,赋值的值是一个二元乘法运算。

Module(body=[
  # x = 1
  Assign(targets=[Name(id='x', ctx=Store())],
         value=Constant(value=1, kind=None),
         type_comment=None),


  # x = x * 2
  Assign(targets=[Name(id='x', ctx=Store())],
         value=BinOp(left=Name(id='x', ctx=Load()), op=Mult(), right=Constant(value=2, kind=None)), type_comment=None),
  
  # y = dl_framework.ones((1, 2))
  Assign(targets=[Name(id='y', ctx=Store())],
         # dl_framework.ones((1, 2))
         
  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值