撰文|李响
1、前言
深度学习框架中模型的运行方式主要有动态图和静态图两种,动态图更易用,静态图性能更具优势,OneFlow 习惯将它们称为 Eager 模式和 Graph 模式。
OneFlow 提供了 nn.Graph 模块,让用户可以用类似 Eager 模式的编程习惯,构建静态图训练测试。因此,需要同时保证 Eager 和 Graph 模式下算子行为和结果的正确性。
在之前的文章《深度学习框架如何优雅地做算子对齐任务》中 ,分析了 Eager Ops 的自动测试流程,包括如何产生随机数据测试用例和 AutoTest 核心代码实现,AutoTest 框架可以很轻易移植到其它深度学习框架使用。
不过,本文的主要目的则是介绍 OneFlow 如何完成 Graph 模式下算子的测试任务。目前为止,OneFlow v0.7.0 已经新增所有 Op 在 nn.Graph
上做静态执行的单测支持,自动化单测功能完备。
文章中涉及到的代码位置:
-
https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
-
https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/test_utils/automated_test_util/generators.py
2、OneFlow 的 Graph 算子对齐概述
OneFlow 提供的 Eager 模式,用法与 PyTorch 对齐。所以在测试上,AutoTest 框架会随机出各种合法参数组合成的 Op ,并基于数值和类型完全相同的输入 Tensor(PyTorch 和 OneFlow 各有一份)分别运行 PyTorch 和 OneFlow 的代码,来完成算子对齐工作。
此外,OneFlow 还提供了 Graph 模式,基于面向对象的编程风格,让熟悉 Eager 开发的用户,只需改很少的代码,就可以高效使用静态图。对比 Eager 模式,Graph 模式不易调试,但性能更好,易于优化和部署。那么,如何自动测试 Graph 模式下的 Ops 就是重点需要关注的问题。
在详细介绍 Graph 单测之前,我们先看一下 AutoTest 框架里 Graph 打开方法,下面是一个测试 matmul 算子的例子。基于 random_pytorch_tensor
方法构造了两个随机的 tensor,它们的维度分别是 [n, k]
和 [k, m]
,这些维度的值都是随机生成的,AutoTest 框架参数的随机性都是基于 generators.py 中的 generator
基类完成的。
@autotest(check_graph = True)
def test_flow_matmul_with_random_data(test_case):
device = random_device()
k = random(1, 6)
x = random_tensor(ndim=2, dim1=k).to(device)
y = random_tensor(ndim=2, dim0=k).to(device)
z = torch.matmul(x, y)
return z
通过调用 torch.matmul(x, y)
,自动测试框架会分别运行 Torch 和 OneFlow 的 matmul 算子,会检查 Eager 模式下 OneFlow 和 PyTorch 算子的前向和反向结果是否一致。值得注意的是,代码中 @autotest
装饰器的 check_graph
开关为 True
,表示此时会并行地做 Graph 的单测。
3、Graph 模式下自动测试实现原理
在了解背景和使用方法后,这里介绍 Graph AutoTest 的实现思路。
3.1 AutoTest 流程介绍
在 Eager 的自动测试原理中,关于随机数据是如何产生的和 autotest()
装饰器的实现,在前文中有清晰的介绍。关于 AutoTest 框架核心流程实现,首先必须要关注用于 OneFlow 和 Pytorch 的算子对齐任务中的 GetDualObject
函数。
GetDualObject
函数会重写传入的原始 PyTorch
以及 OneFlow
对象的 __call__
魔法函数,最后返回一个 DualObject
对象。这个过程中还包含跳过一些不需要关注的魔法函数,检查传入对象的属性是否合法,基于 nn.Module
和其它 API 默认参数的类型对 generator
继承类产生的随机数据绑定特定类型的工作( get_args
函数中完成)。此外,在代码中还有对 tensor
方法的特判,因为 tensor
方法的调用方式(通过 getattr
)和 nn.module
、 nn.functional
不同(通过 __call__
)。
基于上述流程,通过执行样例代码中的 t