定义Core ATen Opset
TL;DR
Meta 内部 PyTorch Core、PyTorch Edge 和 PyTorch Compiler 团队的人员合作审查了常用 ATen 运算符列表,并讨论了是否应将每个运算符添加到core ATen operator set
中,或者由core ATen decomposition table
进行分解。
我们的目标是为 ATen 库定义一个满足以下标准的core ATen operator set
:
core ATen operator set
可用作参考,了解哪些 ATen ops应由使用 PT2 导出的模型的后端或编译器处理。core ATen operator set
所隐含的分解对绝大多数用例都很有用- 绝大多数用例都不想分解
core ATen operator set
中包含的运算符
设计这组核心运算符的目的是帮助 PyTorch 传达一组稳定的运算符,开发人员可以期望他们的模型生成这些运算符,从而将必须在自定义运行时中实现或由自定义编译器后端处理的运算符数量限制在可管理的数量内。它还有助于更容易地与其他的 ML 框架(例如 MLIR、XLA 和 ONNX)集成。
背景:为什么需要核心运算符集?
随着 PyTorch 生态系统的发展,将 PyTorch 模型转换为可在特定环境中高效运行的专用表示的需求也日益增加。目前的具体示例是 TorchInductor 和 Executorch;两者都使用模型的相同 FX Graph 表示,但最终会生成不同的程序以在各自不同的运行时中执行模型。随着更多后端的开发,PyTorch 定义核心运算符集变得至关重要。这一举措也是其他的 ML 框架(例如 MLIR/XLA 和 ONNX)的常见要求,以便于与 PyTorch 更顺畅地集成。
ATen 库中注册了 3000 多个运算符;后端设计人员担心运算符数量变得非常庞大。使问题更加严重的是,许多运算符彼此之间是冗余的,例如一个运算符是另一个运算符的轻微变体(例如就地变体、输出变体)。但是,通过定义核心运算符集,PyTorch 能够传达一组稳定的运算符,开发人员可以期望他们的模型生成这些运算符,从而将必须在自定义运行时中实现或由自定义编译器后端处理的运算符数量限制在可管理的数量内。
Defining the Core Operator Set
核心 ATen 运算符集可以解释成:通过分解运算符从而达到减少注册到 ATen 的所有运算符集的目的。“分解”运算符涉及将其表示为其他运算符的组合;此类分解目前在 decomposition.py 中定义。在导出过程中,使用默认的分解列表;这称为核心 ATen 分解表。因此,核心 ATen 运算符集可以解释为注册到 ATen 的、未进一步分解的运算符列表。
一般情况下,我们定义一个“算子集”为使用特定的“分解表”进行模型导出时会产生的算子列表,因此核心 ATen 算子集就是使用核心 ATen 分解表导出模型时可以包含的算子列表。
@SherlockNoMad 之前已经开始定义核心 ATen 操作集;他确定属于核心 IR 的操作列表可在此处找到:IR — PyTorch 2.0 文档。此列表由出现在 163 个开源模型中的操作符组成,这些模型用作来自 torchbench、HuggingFace 和 TIMM 的 PT2 基准。此时,确定特定 ATen 操作符是否可以“轻松”分解为其他 ATen 操作符的一般标准。
我们现在展示的结果是 Sherlock 之前工作的延续。我们遵循相同的总体流程,手动检查一系列调查模型中出现的操作。然而,在这次迭代中,我们又设定了以下目标:
- 定义和编纂用于评估特定操作的标准,以确定它是否应该成为 ATen 核心操作符集的一部分
- 开发一个民主化的过程,其中对这项工作感兴趣的 PyTorch 中的各种团体(即 Inductor、Edge、Compiler)可以提供有关核心操作符集中应该/不应该包含的内容的意见,并且讨论和结果对更广泛的 PyTorch 社区是透明的
- 描述随着时间的推移发展此操作符集的过程;这涉及向核心集添加新操作符,以及使现有的核心操作符集适应功能模式的变化和添加到 ATen 的新操作符
我们的最终目标是开发一个稳定的核心 ATen 运算符集,以实现以下目标:
- 核心 ATen 运算符集可用作参考,了解哪些 ATen 操作应由使用 PT2 导出的模型的后端或编译器处理。
- 核心 ATen 运算符集所隐含的分解对绝大多数用例都很有用
- 绝大多数用例都不想分解核心 ATen 运算符集中的运算符
核心操作符集表示我们已明确决定不由核心 ATen 分解表分解的所有 ATen 操作符。有些操作符未被核心分解表分解,但也不属于核心 ATen 操作符集;这意味着这些操作符尚未被评估或尚未对这些操作符做出决定。
请注意,我们的目的并不是让用户局限于使用核心 ATen 分解表;后端可以根据自己的意愿自由添加或删除分解。核心运算符集力求成为不同用例和上下文的共同点,但我们鼓励后端进一步微调分解表,从而微调生成的运算符集,以实现其特定目标。
Results
来自 Meta 内部 PyTorch Core、PyTorch Edge 和 PyTorch Compiler 的人员齐聚一堂,共同审查了常用 ATen 运算符列表,并讨论了是否应将每个运算符添加到核心 ATen 运算符集,或由核心 ATen 分解表进行分解。
所考虑的运算符列表是通过提取在 pytorch-jit-paritybench 中测试的大约 10,000 个 nn.Modules 使用的运算符获得的,pytorch-jit-paritybench 是“用于在从热门 GitHub 项目中抓取的许多 nn.Module 上测量 TorchScript 与 PyTorch 奇偶校验的测试套件”。我们的想法是,通过查看模型中明确使用的运算符,我们可以定位影响最大的 ATen 运算符。
我们的决策结果总结如下。
添加到核心 ATen 运算符集的运算符
对于下面列出的运算符,[core aten] 将 ops 添加到核心 aten 集 by angelayi · 拉取请求 #107766 · pytorch/pytorch · GitHub 11 在 native_functions.yaml 中为每个运算符添加了“core”标签。由于 IRs — PyTorch 2.0 文档 180 是通过在 native_functions.yaml 中搜索带有“core”标签的运算符生成的,因此这些运算符最终也会反映在网页中。