作者 | 刘耀辉
1
背景
近年来,量化感知训练是一个较为热点的问题,可以大大优化量化后训练造成精度损失的问题,使得训练过程更加高效。
Torch.fx在这一问题上走在了前列,使用纯Python语言实现了对于Torch.nn.Module的解析和向IR的转换,也可以提供变换后的IR对应的Python代码,在外部则是提供了简洁易用的API,大大方便了量化感知训练过程的搭建。此外,Torch.fx也有助于消除动态图和静态图之间的Gap,可以比较方便地对图进行操作以及进行算子融合。
OneFlow紧随其后添加了针对OneFlow的fx,即One-fx,在安装One-fx之后,用户可以直接调用oneflow.fx,也可以直接通过import onefx as fx进行使用。
one-fx地址:
https://github.com/Oneflow-Inc/one-fx
One-fx实现代码中绝大部分是对于Torch.fx的fork,但根据OneFlow和PyTorch之间存在的差别进行了一些适配或优化。本文将围绕One-fx适配方式以及在OneFlow中的应用展开。
2
FX主要模块
-
Symbolioc Trace
-
Graph Module
-
Interpreter
-
Proxy
-
Passes
其中,前4个模块共同实现了fx的基本功能,Graph Module和Proxy又是Symbolic Trace的基础,Passes则是在此基础上的扩充。
Symbolic Trace的基本概念如上图所示,最基本的模型运行过程就是从模型定义到模型执行这样一个流程。
fx则是进行了非侵入式的解析,将模型执行过程转成一张图,这张图中包含了很多个Node,每一个Node都包含了模型中的子模块或者函数调用信息,然后用户可以很方便地获取到所有的Node,并对其进行一些变换操作,最后通过GraphModule重新生成一个模型定义,并对其执行。
其中,在进行模型解析的时候,节点之间变量传递也均使用代理后的变量,如y = oneflow.relu(x),实际上x和y是Proxy(x)和Proxy(y)。
3
One-fx实现方式
这里给出一个Fx最简单的用例,以方便后续对于实现方式的介绍。
import oneflow
class MyModule(oneflow.nn.Module):
def __init__(self):
super().__init__()
self.linear = oneflow.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
y = oneflow.ones([2, 3])
x = oneflow.relu(x)
return y
m = MyModule()
traced = oneflow.fx.symbolic_trace(m)
print(traced.code)
"""
def forward(self, x):
linear = self.linear(x); x = None
relu = oneflow.relu(linear); linear = None
_tensor_constant0 = self._tensor_constant0
return _tensor_constant0
"""
函数代理
代理,即fx中的Proxy模块,目的是在每次进行函数或模块调用的时候添加一些额外操作,使得对模型的解析和重建得以进行,而包装则是适配代理的一种方式。
torch.fx中,对于nn.Module的包装比较易于理解,每当待解析Module中出现了继承自nn.Module的对象,那么就将其__call__函数替换成包装过的函数。然而,对于pytorch的函数的代理的实现要更“绕”一些,是借助了__torch_function__这一机制
(https://github.com/pytorch/pytorch/blob/c7c723897658eda6298bb74d92e4bb18ab4a5fe3/torch/overrides.py),限于篇幅原因这里不专门对其进行介绍。比较关键的点是,OneFlow中没有这一机制,如果需要添加,那么会是规模很大的、侵入性的,于是One-fx的实现就需要找其它路径。
我们使用的解决方式是搜索oneflow,oneflow.nn.functional,oneflow._C等模块中的Callable,并去除其中属于类的部分,然后对其余函数进行包装,在每次解析模型之前,会将这些模块的__dict__中对应项替换成包装后的函数,并且在解析模型之后重新将这些项进行还原。对于constructor类型的函数,如ones,randn等则不进行代理,直接运行,在最终构建图的时候作为constant来处理。
对于函数的包装部分源码实现如下,每次运行代理后的函数,会先判断该函数的入参中有没有Proxy变量,如果有,那么将会创建一个call_function类型的节点并返回Proxy包装后的节点,否则直接调用原函数并返回结果。
def _create_wrapped_func(orig_fn):
@functools.wraps(orig_fn)
def wrapped(*args, **kwargs):
# 判断参数中是否存在proxy变量
proxy = _find_proxy(args, kwargs)
if proxy is not None:
# 如果参数中有Proxy变量,创建节点并返回Proxy包装后的节点
return_proxy = proxy.tracer.create_proxy(
"call_function", orig_fn, args, kwargs
)
return_proxy.node.meta["is_wrapped"] = True
return return_proxy
# 如果没有Proxy变量,直接调用原函数
return orig_fn(*args, **kwargs)
return wrapped
其中,return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs)这行代码指定了使用与入参相同的Tracer来创建节点并返回结果,create_proxy函数定义的主要部分如下,创建节点并在Proxy包装后返回。
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
args_ = self.cre