Dynamo
目的是把用户程序抓成图,然后编译这段图优化加速,然后执行这段加速代码替代用户的代码执行、
这里会改变实际执行的字节码,具体后面细看,简单来说就是把源代码部分抓成图,然后编译,返回这部分的编译函数来调用。
# 一个普通的函数
def fn(a, b):
x = a + b
x = x / 2.0
if x.sum() < 0:
return x * -1.0
return x
# torchdynamo 函数接口
with torchdynamo.optimize(custom_compiler):
fn(torch.randn(10), torch.randn(10))
源码生成的字节码
# x = a + b
0 LOAD_FAST 0 (a)
2 LOAD_FAST 1 (b)
4 BINARY_ADD
6 STORE_FAST 2 (x)
# x = x / 2.0
8 LOAD_FAST 2 (x)
10 LOAD_CONST 1 (2.0)
12 BINARY_TRUE_DIVIDE
14 STORE_FAST 2 (x)
# if x.sum() < 0:
16 LOAD_FAST 2 (x)
18 LOAD_METHOD 0 (sum)
20 CALL_METHOD 0
22 LOAD_CONST 2 (0)
24 COMPARE_OP 0 (<)
26 POP_JUMP_IF_FALSE 36
# return x * -1.0
28 LOAD_FAST 2 (x)
30 LOAD_CONST 3 (-1.0)
32 BINARY_MULTIPLY
34 RETURN_VALUE
# return x
36 LOAD_FAST 2 (x)
38 RETURN_VALUE
经过改写后的
# x = a + b
# x = x / 2.0
# x.sum() < 0
# 上面两行被转换成了 __compiled_fn_0
# __compiled_fn_0 会返回 x 和 x.sum() < 0 组成的 tuple
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 22
# x * -1.0 被转换成了 __compiled_fn_1
14 LOAD_GLOBAL 2 (__compiled_fn_1)
16 LOAD_FAST 2 (x)
18 CALL_FUNCTION 1
20 RETURN_VALUE
# return x
22 LOAD_FAST 2 (x)
24 RETURN_VALUE
执行顺序
以下是每个字节码指令的解释及其在调用栈中的表现:
LOAD_GLOBAL 1 (__compiled_fn_0):
- 从全局命名空间加载函数 __compiled_fn_0。
- 调用栈状态:[__compiled_fn_0]
LOAD_FAST 0 (a):
- 加载局部变量 a。
- 调用栈状态:[__compiled_fn_0, a]
LOAD_FAST 1 (b):
- 加载局部变量 b。
- 调用栈状态:[__compiled_fn_0, a, b]
CALL_FUNCTION 2:
- 调用 __compiled_fn_0,传入 a 和 b 两个参数。
- 返回一个包含 x 和 x.sum() < 0 的 tuple。
- 调用栈状态:[(x, condition)]
UNPACK_SEQUENCE 2:
- 解包 tuple。
- 调用栈状态:[x_value, condition]
STORE_FAST 2 (x):
- 将 x 存储到局部变量槽中。
- 调用栈状态:[condition]
POP_JUMP_IF_FALSE 22:
- 如果 condition 为 False,跳转到字节码偏移量 22。
- 调用栈状态:[]
LOAD_GLOBAL 2 (__compiled_fn_1):
- 从全局命名空间加载函数 __compiled_fn_1。
- 调用栈状态:[__compiled_fn_1]
LOAD_FAST 2 (x):
- 加载局部变量 x。
- 调用栈状态:[__compiled_fn_1, x]
CALL_FUNCTION 1:
- 调用 __compiled_fn_1,传入 x 作为参数。
- 调用栈状态:[result]
RETURN_VALUE:
- 返回 result。
- 调用栈状态:[]
LOAD_FAST 2 (x):
- 如果 condition 为 False,跳转到这里并加载 x。
- 调用栈状态:[x]
RETURN_VALUE:
返回 x。
调用栈状态:[]
两个函数的图
__compiled_fn_0:
opcode name target args kwargs
------------- ------- --------------------------- ---------------- --------
placeholder a_0 a_0 () {}
placeholder b_1 b_1 () {}
call_function add <built-in function add> (a_0, b_1) {}
call_function truediv <built-in function truediv> (add, 2.0) {}
call_method sum_1 sum (truediv,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
__compiled_fn_1:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder x_4 x_4 () {}
call_function mul <built-in function mul> (x_4, -1.0) {}
output output output (mul,) {}