Preface
Introduced in torch 1.8, officially released in torch 1.10.
It is roughly a syntax-tree formalism: a fixed-format grammar that describes an abstract computation graph.
From this syntax tree you can generate code in a standard format (the .code attribute).
Conversely, the syntax tree can be parsed out of a Module instance.
Anyone who has played with syntax-tree analysis will find this familiar.
The official term appears to be IR (intermediate representation);
an IR is the intermediate layer between a high-level language and assembly.
The 6 kinds of IR nodes (descriptions quoted from the official docs)
- placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don’t-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.
- get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter’s position in the module hierarchy. args and kwargs are don’t-care
- call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention
- call_module applies a module in the module hierarchy’s forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, excluding the self argument.
- call_method calls a method on a value. name is as similar. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the module on, including the self argument
- output contains the output of the traced function in its args[0] attribute. This corresponds to the “return” statement in the Graph printout.
Example
# source: https://www.jianshu.com/p/bc14593480b9
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
# Print the FX IR
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
# The code generated by FX; it can be read as the module's forward code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
Module -> fx.Graph
Use symbolic_trace.
symbolic_trace returns an fx.GraphModule; GraphModule is a subclass of nn.Module.
fx.GraphModule has a graph attribute of type fx.Graph, which represents the computation graph described by the IR.
import torch
from torch import nn
from torch.fx import symbolic_trace
# Define a simple model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
# Symbolically trace the model with symbolic_trace
model = SimpleModel()
traced_model = symbolic_trace(model)
# Check the type of traced_model
print(type(traced_model)) # prints: <class 'torch.fx.graph_module.GraphModule'>
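Besides print(traced_model.graph), the graph also has a tabular view (note: print_tabular requires the third-party tabulate package to be installed):
# One row per node: opcode, name, target, args, kwargs
traced_model.graph.print_tabular()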
fx.Graph -> Module
Use fx.GraphModule.
Pass in a module instance together with the new computation graph (an fx.Graph).
The new graph replaces the instance's forward function.
GraphModule is itself a subclass of nn.Module and can be used directly.
from torch import fx
from torch.fx import GraphModule
graph = fx.Graph()
# ... some edits to the computation graph
new_model = GraphModule(model, graph, "revisedModel")
out = new_model(inputs)
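A concrete end-to-end sketch (reusing SimpleModel from above; the names here are illustrative): trace the model, append a torch.relu call to its graph, then rebuild a GraphModule:
import torch
from torch.fx import symbolic_trace, GraphModule

traced = symbolic_trace(SimpleModel())
g = traced.graph
# Locate the output node and wrap its argument in a relu call
output_node = next(n for n in g.nodes if n.op == "output")
with g.inserting_before(output_node):
    relu_node = g.call_function(torch.relu, args=(output_node.args[0],))
output_node.args = (relu_node,)
g.lint()                                 # sanity-check the edited graph
relu_model = GraphModule(traced, g)
print(relu_model.code)                   # forward now ends with a relu
out = relu_model(torch.rand(2, 10))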
fx to script
The goal of script is to optimize the model and run it detached from Python.
The goal of fx is to modify the model: Python-to-Python code transformation.
Code produced by fx runs on the native Python interpreter.
A scripted model runs on a VM in C++.
fx traces forward and yields a dynamic graph.
script converts it into a static graph.
Converting to script is a one-way operation.
# graph_module is already a GraphModule
scripted_module = torch.jit.script(graph_module)
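The scripted module can then be saved and reloaded with no Python source around, which is the whole point of leaving Python (a minimal sketch; the file name is arbitrary):
scripted_module.save("tp_scripted.pt")      # archive containing code + weights
loaded = torch.jit.load("tp_scripted.pt")   # runs on the C++ VM, no Python class needed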
Another workflow: hand-writing an fx.Graph from the IR
As we know, fx.Graph describes the computation graph through the IR.
It can be written by hand.
That makes this workflow quite similar to TensorFlow:
a predefined computation graph.
from torch import fx
graph = fx.Graph() # an empty computation graph
tracer = fx.proxy.GraphAppendingTracer(graph) # an empty tracer; the "appending" in the name says it all.
x1s = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer)
x2s = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer)
weights = fx.Proxy(graph.placeholder("w", torch.Tensor), tracer=tracer)
# These three proxies can now be used as if they were ordinary tensors
out = (x1s+x2s)*weights # though the result is still a Proxy.
out = out.view(-1) # still a Proxy
# Finally, emit an output IR node to close the graph
# the first argument must be proxy.node
graph.output(out.node, torch.Tensor)
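To actually execute this hand-written graph, wrap it in a GraphModule (a sketch: since the graph has no get_attr / call_module nodes, a bare nn.Module suffices as root):
import torch
from torch import nn, fx
gm = fx.GraphModule(nn.Module(), graph)
print(gm.code)   # def forward(self, x1, x2, w): ...
out = gm(torch.ones(2, 3), torch.ones(2, 3), torch.ones(2, 3))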
Speed comparison
The e3nn setup
irreps_in1 = o3.Irreps("1x0e+2x1o+3x2e")
irreps_in2 = o3.Irreps("1x0e+2x1o+3x2e")
irreps_out = o3.Irreps("1x0e+2x1o+3x2e")
instr = [
(i_1, i_2, i_out, "uvw", True, 1.0)
for i_1, (_, ir_1) in enumerate(irreps_in1)
for i_2, (_, ir_2) in enumerate(irreps_in2)
for i_out, (_, ir_out) in enumerate(irreps_out)
if ir_out in ir_1 * ir_2
]
Prepare four models:
model0: e3nn's native TP, itself written with torch.fx, with script optimization on by default.
model1: a computation graph written with torch.fx, plus script. I changed it a little.
model2: a regular eager-mode nn.Module forward, with a cuda.stream optimization.
model3: a regular eager-mode nn.Module forward.
model0 = TensorProduct(irreps_in1,irreps_in2,irreps_out,instr,_specialized_code=True,_optimize_einsums=True)
model1 = QTensorProduct1(irreps_in1,irreps_in2,irreps_out,instr)
model2 = QTensorProduct2(irreps_in1,irreps_in2,irreps_out,instr)
model3 = QTensorProduct3(irreps_in1,irreps_in2,irreps_out,instr)
a = torch.ones((32, irreps_in1.dim))
b = torch.ones((32, irreps_in2.dim))
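qt.Timer below is a helper from my own toolbox (qt is not a public package); a minimal equivalent sketch:
import time

class Timer:
    """Context manager that prints wall-clock time for the enclosed block."""
    def __init__(self, label=""):
        self.label = label
    def __enter__(self):
        self.start = time.perf_counter()
        return self
    def __exit__(self, *exc):
        elapsed = time.perf_counter() - self.start
        print(f"Execution time: {elapsed:.2f} seconds {self.label}".rstrip())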
Timing on cpu
# warm up
model0(a,b)
model1(a,b)
model2(a,b)
model3(a,b)
with qt.Timer(''):
    for i in range(1000):
        model0(a,b)
with qt.Timer(''):
    for i in range(1000):
        model1(a,b)
with qt.Timer(''):
    for i in range(1000):
        model2(a,b)
with qt.Timer(''):
    for i in range(1000):
        model3(a,b)
Execution time: 0.89 seconds # without script this would be around 1.3
Execution time: 0.83 seconds # the rewritten version is marginally better than native e3nn.
Execution time: 3.81 seconds
Execution time: 3.79 seconds
cpu + compile
opt_model1 = torch.compile(model1)
opt_model2 = torch.compile(model2)
opt_model3 = torch.compile(model3)
# warm up
opt_model1(a,b)
opt_model2(a,b)
opt_model3(a,b)
with qt.Timer(''):
for i in range(1000):
opt_model1(a,b)
with qt.Timer(''):
for i in range(1000):
opt_model2(a,b)
with qt.Timer(''):
for i in range(1000):
opt_model3(a,b)
Execution time: 0.84 seconds
Execution time: 4.23 seconds # the stream ops were optimized away by compile.
Execution time: 4.24 seconds # slower than without compile
gpu
cu_model1 = model1.cuda()
cu_model2 = model2.cuda()
cu_model3 = model3.cuda()
cuda_a = a.cuda()
cuda_b = b.cuda()
# warm up
cu_model1(cuda_a, cuda_b)
cu_model2(cuda_a, cuda_b)
cu_model3(cuda_a, cuda_b)
with qt.Timer(''):
for i in range(1000):
cu_model1(cuda_a, cuda_b)
with qt.Timer(''):
for i in range(1000):
cu_model2(cuda_a, cuda_b)
with qt.Timer(''):
for i in range(1000):
cu_model3(cuda_a, cuda_b)
Execution time: 1.07 seconds # slower than on the cpu...
Execution time: 4.14 seconds # the stream version now beats its cpu run, but still loses to computing directly.
Execution time: 4.63 seconds # slower than on the cpu.
gpu + compile
opt_cumodel1_inductor = torch.compile(cu_model1, backend='inductor')
opt_cumodel1_cugraph = torch.compile(cu_model1, backend='cudagraphs')
opt_cumodel2_inductor = torch.compile(cu_model2, backend='inductor')
opt_cumodel2_cugraph = torch.compile(cu_model2, backend='cudagraphs')
opt_cumodel3_inductor = torch.compile(cu_model3, backend='inductor')
opt_cumodel3_cugraph = torch.compile(cu_model3, backend='cudagraphs')
# warm up
opt_cumodel1_inductor(cuda_a, cuda_b)
opt_cumodel1_cugraph(cuda_a, cuda_b)
opt_cumodel2_inductor(cuda_a, cuda_b)
opt_cumodel2_cugraph(cuda_a, cuda_b)
opt_cumodel3_inductor(cuda_a, cuda_b)
opt_cumodel3_cugraph(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel1_inductor(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel1_cugraph(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel2_inductor(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel2_cugraph(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel3_inductor(cuda_a, cuda_b)
with qt.Timer(''):
    for i in range(1000):
        opt_cumodel3_cugraph(cuda_a, cuda_b)
Execution time: 1.09 seconds # about the same as without compile, a touch slower.
Execution time: 1.08 seconds
Execution time: 4.58 seconds # slower than without compile
Execution time: 4.65 seconds
Execution time: 4.99 seconds # slower than without compile
Execution time: 5.03 seconds # overall, inductor still looks a bit better.
A look at the code of one TensorProduct after fx and then script
def forward(self,
x1: Tensor,
x2: Tensor,
w: Tensor) -> Tensor:
_0 = __torch__.torch.functional.___torch_mangle_0.tensordot
empty = torch.empty(annotate(List[int], []), dtype=None, layout=None, device=torch.device("cpu"))
getattr_1 = torch.size(x1)
getitem = torch.slice(getattr_1, 0, -1)
expand = torch.expand(empty, getitem)
getattr_2 = torch.size(x2)
getitem_1 = torch.slice(getattr_2, 0, -1)
expand_1 = torch.expand(empty, getitem_1)
broadcast_tensors = torch.broadcast_tensors([expand, expand_1])
getitem_2 = broadcast_tensors[0]
getattr_3 = torch.size(getitem_2)
add = torch.add(getattr_3, [-1])
expand_2 = torch.expand(x1, add)
expand_3 = torch.expand(x2, add)
add_1 = torch.add(getattr_3, [22])
reshape = torch.reshape(expand_2, [-1, 22])
reshape_1 = torch.reshape(expand_3, [-1, 22])
getattr_4 = torch.size(reshape)
getitem_3 = getattr_4[0]
reshape_2 = torch.reshape(w, [-1, 103])
_1 = torch.slice(reshape, 0, 0, 9223372036854775807)
getitem_4 = torch.slice(_1, 1, 0, 1)
reshape_3 = torch.reshape(getitem_4, [getitem_3, 1, 1])
_2 = torch.slice(reshape, 0, 0, 9223372036854775807)
getitem_5 = torch.slice(_2, 1, 1, 7)
reshape_4 = torch.reshape(getitem_5, [getitem_3, 2, 3])
_3 = torch.slice(reshape, 0, 0, 9223372036854775807)
getitem_6 = torch.slice(_3, 1, 7, 22)
reshape_5 = torch.reshape(getitem_6, [getitem_3, 3, 5])
_4 = torch.slice(reshape_1, 0, 0, 9223372036854775807)
getitem_7 = torch.slice(_4, 1, 0, 1)
reshape_6 = torch.reshape(getitem_7, [getitem_3, 1, 1])
_5 = torch.slice(reshape_1, 0, 0, 9223372036854775807)
getitem_8 = torch.slice(_5, 1, 1, 7)
reshape_7 = torch.reshape(getitem_8, [getitem_3, 2, 3])
_6 = torch.slice(reshape_1, 0, 0, 9223372036854775807)
getitem_9 = torch.slice(_6, 1, 7, 22)
reshape_8 = torch.reshape(getitem_9, [getitem_3, 3, 5])
_7 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_10 = torch.slice(_7, 1, 0, 1)
reshape_9 = torch.reshape(getitem_10, [1, 1, 1])
reshape_10 = torch.reshape(reshape_3, [getitem_3, 1])
reshape_11 = torch.reshape(reshape_6, [getitem_3, 1])
einsum_1 = torch.einsum("cb,ca->cba", [reshape_11, reshape_10])
tensordot = _0(einsum_1, reshape_9, ([1, 2], [1, 0]), None, )
reshape_12 = torch.reshape(tensordot, [getitem_3, 1])
_8 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_11 = torch.slice(_8, 1, 1, 5)
reshape_13 = torch.reshape(getitem_11, [1, 2, 2])
reshape_14 = torch.reshape(reshape_3, [getitem_3, 1])
einsum_3 = torch.einsum("dca,db->dcab", [reshape_7, reshape_14])
tensordot_1 = _0(einsum_3, reshape_13, ([1, 3], [1, 0]), None, )
permute = torch.permute(tensordot_1, [0, 2, 1])
reshape_15 = torch.reshape(permute, [getitem_3, 6])
_9 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_12 = torch.slice(_9, 1, 5, 14)
reshape_16 = torch.reshape(getitem_12, [1, 3, 3])
reshape_17 = torch.reshape(reshape_3, [getitem_3, 1])
einsum_5 = torch.einsum("dca,db->dcab", [reshape_8, reshape_17])
tensordot_2 = _0(einsum_5, reshape_16, ([1, 3], [1, 0]), None, )
permute_1 = torch.permute(tensordot_2, [0, 2, 1])
reshape_18 = torch.reshape(permute_1, [getitem_3, 15])
_10 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_13 = torch.slice(_10, 1, 14, 18)
reshape_19 = torch.reshape(getitem_13, [2, 1, 2])
reshape_20 = torch.reshape(reshape_6, [getitem_3, 1])
einsum_7 = torch.einsum("dc,dba->dcba", [reshape_20, reshape_4])
tensordot_3 = _0(einsum_7, reshape_19, ([1, 2], [1, 0]), None, )
permute_2 = torch.permute(tensordot_3, [0, 2, 1])
reshape_21 = torch.reshape(permute_2, [getitem_3, 6])
_11 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_14 = torch.slice(_11, 1, 18, 22)
reshape_22 = torch.reshape(getitem_14, [2, 2, 1])
einsum_8 = torch.einsum("dca,dba->dcb", [reshape_7, reshape_4])
mul = torch.mul(reshape_22, 0.28867513459481292)
tensordot_4 = _0(einsum_8, mul, ([1, 2], [1, 0]), None, )
reshape_23 = torch.reshape(tensordot_4, [getitem_3, 1])
_12 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_15 = torch.slice(_12, 1, 22, 34)
reshape_24 = torch.reshape(getitem_15, [2, 2, 3])
_w3j_1_1_2 = self._w3j_1_1_2
tensordot_5 = _0(reshape_4, _w3j_1_1_2, ([2], [0]), None, )
einsum_9 = torch.einsum("ecab,eda->ecbd", [tensordot_5, reshape_7])
mul_15 = torch.mul(reshape_24, 1.9364916731037085)
tensordot_6 = _0(einsum_9, mul_15, ([3, 1], [1, 0]), None, )
permute_3 = torch.permute(tensordot_6, [0, 2, 1])
reshape_25 = torch.reshape(permute_3, [getitem_3, 15])
_13 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_16 = torch.slice(_13, 1, 34, 46)
reshape_26 = torch.reshape(getitem_16, [2, 3, 2])
_w3j_1_2_1 = self._w3j_1_2_1
tensordot_7 = _0(reshape_8, reshape_26, ([1], [1]), None, )
tensordot_8 = _0(reshape_4, _w3j_1_2_1, ([2], [0]), None, )
einsum_10 = torch.einsum("ecab,eacd->edb", [tensordot_8, tensordot_7])
reshape_27 = torch.reshape(einsum_10, [getitem_3, 6])
_14 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_17 = torch.slice(_14, 1, 46, 55)
reshape_28 = torch.reshape(getitem_17, [3, 1, 3])
reshape_29 = torch.reshape(reshape_6, [getitem_3, 1])
einsum_12 = torch.einsum("dc,dba->dcba", [reshape_29, reshape_5])
tensordot_9 = _0(einsum_12, reshape_28, ([1, 2], [1, 0]), None, )
permute_4 = torch.permute(tensordot_9, [0, 2, 1])
reshape_30 = torch.reshape(permute_4, [getitem_3, 15])
_15 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_18 = torch.slice(_15, 1, 55, 67)
reshape_31 = torch.reshape(getitem_18, [3, 2, 2])
_w3j_2_1_1 = self._w3j_2_1_1
tensordot_10 = _0(reshape_5, reshape_31, ([1], [0]), None, )
tensordot_11 = _0(reshape_7, _w3j_2_1_1, ([2], [1]), None, )
einsum_13 = torch.einsum("ecab,eacd->edb", [tensordot_11, tensordot_10])
reshape_32 = torch.reshape(einsum_13, [getitem_3, 6])
_16 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_19 = torch.slice(_16, 1, 67, 76)
reshape_33 = torch.reshape(getitem_19, [3, 3, 1])
einsum_14 = torch.einsum("dca,dba->dcb", [reshape_8, reshape_5])
tensordot_12 = _0(einsum_14, reshape_33, ([1, 2], [1, 0]), None, )
mul_16 = torch.mul(tensordot_12, 0.14907119849998596)
reshape_34 = torch.reshape(mul_16, [getitem_3, 1])
_17 = torch.slice(reshape_2, 0, 0, 9223372036854775807)
getitem_20 = torch.slice(_17, 1, 76, 103)
reshape_35 = torch.reshape(getitem_20, [3, 3, 3])
_w3j_2_2_2 = self._w3j_2_2_2
mul_17 = torch.mul(reshape_35, 1.2909944487358056)
tensordot_13 = _0(reshape_5, mul_17, ([1], [0]), None, )
tensordot_14 = _0(reshape_8, _w3j_2_2_2, ([2], [1]), None, )
einsum_15 = torch.einsum("ecab,eacd->edb", [tensordot_14, tensordot_13])
reshape_36 = torch.reshape(einsum_15, [getitem_3, 15])
add_2 = torch.add(reshape_12, reshape_23)
add_3 = torch.add(add_2, reshape_34)
add_4 = torch.add(reshape_15, reshape_21)
add_5 = torch.add(add_4, reshape_27)
add_6 = torch.add(add_5, reshape_32)
add_7 = torch.add(reshape_18, reshape_25)
add_8 = torch.add(add_7, reshape_30)
add_9 = torch.add(add_8, reshape_36)
cat = torch.cat([add_3, add_6, add_9], 1)
return torch.reshape(cat, add_1)
Why fx is more powerful
I took a closer look.
Hand-writing the computation graph with fx and then scripting it is a stronger scheme than writing an nn.Module and scripting that.
1. Support for custom classes
When script processes an nn.Module's forward, it assumes almost all non-Tensor work was done in __init__, considers only tensor operations inside forward, and treats every variable as a Tensor by default.
If you bring in a custom class, scripting breaks easily.
Another case: you need to use a saved slice inside forward (see the sketch below).
fx is somewhat friendlier to third-party classes.
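A sketch of the saved-slice case (whether script rejects it may depend on the version; the class name is illustrative):
import torch
from torch import nn, fx

class SliceModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sl = slice(1, 7)            # a plain Python slice stored on the module
    def forward(self, x):
        return x[:, self.sl]

m = SliceModule()
fx.symbolic_trace(m)    # fine: the slice is baked into the graph at trace time
# torch.jit.script(m)   # typically errors: `slice` is not a supported attribute type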
2. Support for almost every operation
Building a key with an f-string and then dynamically fetching a buffer by that key does not work in the normal Module -> script path.
But it does work in an fx predefined graph, as the sketch below shows.
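A minimal sketch (buffer names like w_{i} are illustrative): when the graph is built by hand, the f-string is resolved at graph-construction time, so the final IR contains only plain get_attr nodes that script can digest:
import torch
from torch import nn, fx

root = nn.Module()
for i in range(3):
    root.register_buffer(f"w_{i}", torch.full((4,), float(i)))

graph = fx.Graph()
tracer = fx.proxy.GraphAppendingTracer(graph)
x = fx.Proxy(graph.placeholder("x", torch.Tensor), tracer=tracer)
out = x
for i in range(3):
    # dynamic key, resolved right now while building the graph
    w = fx.Proxy(graph.get_attr(f"w_{i}"), tracer=tracer)
    out = out + w
graph.output(out.node, torch.Tensor)

gm = fx.GraphModule(root, graph)
scripted = torch.jit.script(gm)   # scripts fine: no f-string left in the IR
print(scripted(torch.zeros(4)))   # tensor([3., 3., 3., 3.])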
3. Control-flow optimization
A pattern you see all the time: inside forward, the computation is chosen based on some self.aaa.
In an fx.Graph, only the taken path is kept, automatically.
In an nn.Module's forward, the if still has to run on every call; see the sketch below.
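A sketch of the specialization (assuming self.flag is a plain Python bool, so tracing evaluates the if once and the branch vanishes from the generated code):
import torch
from torch import nn, fx

class Branchy(nn.Module):
    def __init__(self, flag: bool):
        super().__init__()
        self.flag = flag
    def forward(self, x):
        if self.flag:
            return x * 2
        return x + 1

print(fx.symbolic_trace(Branchy(True)).code)   # only the x * 2 path survives
print(fx.symbolic_trace(Branchy(False)).code)  # only the x + 1 path survives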
Stream control
For now, fx cannot capture cuda.stream.
class MyModel(torch.nn.Module):
def forward(self, x):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
return x * 2
model = MyModel()
graphModel = torch.fx.symbolic_trace(model)
print(graphModel.code)
# the stream-related code has disappeared.
# def forward(self, x):
# mul = x * 2; x = None
# return mul
It seems that no with-based context manager is captured by fx at all.
But they can apparently be added by hand after scripting.
Referring to script's own implementation, the stream can be turned into call_function IR:
streams = annotate(List[__torch__.torch.classes.cuda.Stream], [])
for _58 in range(num_instr):
_59 = __torch__.torch.classes.cuda.Stream.__new__(__torch__.torch.classes.cuda.Stream)
_60 = (_59).__init__(None, 0, )
_61 = torch.append(streams, _59)
out_arrs = annotate(List[List[Tensor]], [])
for _62 in range(_num_out):
_63 = torch.append(out_arrs, annotate(List[Tensor], []))
for i in range(num_instr):
_64 = __torch__.torch.cuda.stream(streams[i], )
with _64:
# ... (computation elided)
result8 = torch.mul(result, path_weight)
result9 = torch.reshape(result8, [batch_numel, _dim_out])
_94 = torch.append(out_arrs[i_out], result9)
Hand-writing an fx graph and manipulating Node objects
import torch
from torch import fx
graph = fx.Graph()
# = Function definitions =
tracer = fx.proxy.GraphAppendingTracer(graph)
x1 = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer)
x2 = x1 + x1
x3 = x2 * x1
x4 = x3 * x2
# Print the graph
def print_node(node):
print(f"Node name: {node.name}, Target: {node.target}, Inputs: {node.args} {node.kwargs}")
for node in graph.nodes:
print_node(node)
# Node name: x1, Target: x1, Inputs: () {}
# Node name: add, Target: <built-in function add>, Inputs: (x1, x1) {}
# Node name: mul, Target: <built-in function mul>, Inputs: (add, x1) {}
# Node name: mul_1, Target: <built-in function mul>, Inputs: (mul, add) {}
# Rename just the x3 node
print_node(x3.node)
x3.node.name = "x3"
print_node(x3.node)
# Node name: mul, Target: <built-in function mul>, Inputs: (add, x1) {}
# Node name: x3, Target: <built-in function mul>, Inputs: (add, x1) {}
# Print the graph again
for node in graph.nodes:
print_node(node)
# Node name: x1, Target: x1, Inputs: () {}
# Node name: add, Target: <built-in function add>, Inputs: (x1, x1) {}
# Node name: x3, Target: <built-in function mul>, Inputs: (add, x1) {}
# Node name: mul_1, Target: <built-in function mul>, Inputs: (x3, add) {}
# The last line shows it: after we renamed the x3 node, the inputs of the affected mul_1 node were updated to match.
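Going one step further, the usual editing APIs work on these Node objects too (a sketch, continuing with the same graph; x5 is a new illustrative node):
# Insert a replacement node before x3 so topological order is preserved
with graph.inserting_before(x3.node):
    x5 = x2 * 2
# Redirect every consumer of x3 (here mul_1) to the new node
x3.node.replace_all_uses_with(x5.node)
graph.erase_node(x3.node)   # allowed once x3 has no users left
graph.output(x4.node)       # close the graph
graph.lint()                # sanity-check the edited graph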