(2)Dynamo

Dynamo


目的是把用户程序抓成图,然后编译这段图优化加速,然后执行这段加速代码替代用户的代码执行、

这里会改变实际执行的字节码,具体后面细看,简单来说就是把源代码部分抓成图,然后编译,返回这部分的编译函数来调用。

# 一个普通的函数
def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x
 
# torchdynamo 函数接口
with torchdynamo.optimize(custom_compiler):   
   fn(torch.randn(10), torch.randn(10))

源码生成的字节码

 # x = a + b
 0  LOAD_FAST 0 (a)
 2  LOAD_FAST 1 (b)
 4  BINARY_ADD
 6  STORE_FAST 2 (x)

 # x = x / 2.0
 8  LOAD_FAST 2 (x)
 10 LOAD_CONST 1 (2.0)
 12 BINARY_TRUE_DIVIDE
 14 STORE_FAST 2 (x)

 # if x.sum() < 0:
 16 LOAD_FAST 2 (x)
 18 LOAD_METHOD 0 (sum)
 20 CALL_METHOD 0
 22 LOAD_CONST 2 (0)
 24 COMPARE_OP 0 (<)
 26 POP_JUMP_IF_FALSE 36
 
 # return x * -1.0
 28 LOAD_FAST 2 (x)
 30 LOAD_CONST 3 (-1.0)
 32 BINARY_MULTIPLY
 34 RETURN_VALUE

 # return x
 36 LOAD_FAST 2 (x)
 38 RETURN_VALUE

经过改写后的

 # x = a + b
 # x = x / 2.0
 # x.sum() < 0
 # 上面两行被转换成了 __compiled_fn_0
 # __compiled_fn_0 会返回 x 和 x.sum() < 0 组成的 tuple
 0  LOAD_GLOBAL 1 (__compiled_fn_0)
 2  LOAD_FAST 0 (a)
 4  LOAD_FAST 1 (b)
 6  CALL_FUNCTION 2
 8  UNPACK_SEQUENCE 2
 10 STORE_FAST 2 (x)
 12 POP_JUMP_IF_FALSE 22
 
 # x * -1.0 被转换成了 __compiled_fn_1 
 14 LOAD_GLOBAL 2 (__compiled_fn_1)
 16 LOAD_FAST 2 (x)
 18 CALL_FUNCTION 1
 20 RETURN_VALUE

 # return x
 22 LOAD_FAST 2 (x)
 24 RETURN_VALUE

执行顺序

以下是每个字节码指令的解释及其在调用栈中的表现:

LOAD_GLOBAL 1 (__compiled_fn_0):

  • 从全局命名空间加载函数 __compiled_fn_0。
  • 调用栈状态:[__compiled_fn_0]

LOAD_FAST 0 (a):

  • 加载局部变量 a。
  • 调用栈状态:[__compiled_fn_0, a]

LOAD_FAST 1 (b):

  • 加载局部变量 b。
  • 调用栈状态:[__compiled_fn_0, a, b]

CALL_FUNCTION 2:

  • 调用 __compiled_fn_0,传入 a 和 b 两个参数。
  • 返回一个包含 x 和 x.sum() < 0 的 tuple。
  • 调用栈状态:[(x, condition)]

UNPACK_SEQUENCE 2:

  • 解包 tuple。
  • 调用栈状态:[x_value, condition]

STORE_FAST 2 (x):

  • 将 x 存储到局部变量槽中。
  • 调用栈状态:[condition]

POP_JUMP_IF_FALSE 22:

  • 如果 condition 为 False,跳转到字节码偏移量 22。
  • 调用栈状态:[]

LOAD_GLOBAL 2 (__compiled_fn_1):

  • 从全局命名空间加载函数 __compiled_fn_1。
  • 调用栈状态:[__compiled_fn_1]

LOAD_FAST 2 (x):

  • 加载局部变量 x。
  • 调用栈状态:[__compiled_fn_1, x]

CALL_FUNCTION 1:

  • 调用 __compiled_fn_1,传入 x 作为参数。
  • 调用栈状态:[result]

RETURN_VALUE:

  • 返回 result。
  • 调用栈状态:[]

LOAD_FAST 2 (x):

  • 如果 condition 为 False,跳转到这里并加载 x。
  • 调用栈状态:[x]

RETURN_VALUE:
返回 x。
调用栈状态:[]

两个函数的图

__compiled_fn_0:
opcode         name     target                       args              kwargs
-------------  -------  ---------------------------  ----------------  --------
placeholder    a_0      a_0                          ()                {}
placeholder    b_1      b_1                          ()                {}
call_function  add      <built-in function add>      (a_0, b_1)        {}
call_function  truediv  <built-in function truediv>  (add, 2.0)        {}
call_method    sum_1    sum                          (truediv,)        {}
call_function  lt       <built-in function lt>       (sum_1, 0)        {}
output         output   output                       ((truediv, lt),)  {}

__compiled_fn_1:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    x_4     x_4                      ()           {}
call_function  mul     <built-in function mul>  (x_4, -1.0)  {}
output         output  output                   (mul,)       {}
  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值