[Pytorch 源码阅读] —— 谈谈 dispatcher(一)

前言

这篇文章的内容主要还是基于 EdWard z. yang 的 Let’s talk about the PyTorch dispatcher 来梳理一下 Pytorch dispatcher 相关的内容学习以及源码阅读。其中会涉及到很多类和内容,也会出现很多源代码,所以文章篇幅会很长,这里我分成了两篇文章进行讲解,这篇主要进行概念性介绍,另一篇则是源码阅读。

概念介绍

dispatcher 可以理解为分发器,可以根据关于 tensor 输入的一些信息来决定要调用哪一块的程序。其主要是通过分发表(dispatch table)的形式来实现的,如下图:

分发表中包含了相关的 dispatch key 和对应的函数指针,可以看到 dispatch key 不仅有硬件后端,例如 CPU,GPU 等,还有一些更抽象的概念, 例如 autograd 和 tracing。dispatcher 的工作就是根据输入的 tensor 还有一些其他因素(参数个数,返回值类型等)去计算得到一个 dispatch key ,然后根据 dispatch table 找到对应对函数指针,一般我们也称为 kernel,然后间接跳转调用它。

概念介绍起来可能上面的一小段话就可以把 dispatcher 介绍完毕了,但是深入细节,其 dispatch key 是何种形式表现的?如何去进行 dispatch key 计算的?包括如何将它融入到 Pytorch 整个系统等问题,还是有很多内容值得一说的。

diapatch key 的表示和计算

diapatch key 的计算是通过一个 dispatch key set 的结构来实现的, dispatch key set 可以理解为一个 64bit 的数组,每一个 bit 都代表了一个 dispatch key,然后从左到右有优先级关系,同样一个算子,可能有针对不同 dispatch key 的实现,针对这些散落在 Pytorch各处的注册实现,然后也可能会将某些 key 排除的情况,这个数组通过最终将关于同一个算子的所有可用实现都结合到一起,调用其中优先级最高的 dispatch key 对应的 kernel 实现。

dispatch table 注册

下一个问题就是这些函数指针是如何出现在 dispatch table 中的呢?这个主要是在算子注册的时候一步步添加到 dispatch table 里面的,主要是使用了 C++ 代码,注册算子的整个过程官方有对应的文档 。主要算子注册有一下 3 种算子注册交互方式:

首先需要 m.def 定义一个关于算子的 schema(后面会介绍),然后 m.impl 将带有 dispatch key 信息的算子实现注册到 dispatch table 中,最后还有一个名为 m.fallback 的方式,在为所有的算子都注册上同一个 dispatch key。想象一下 dispatch table 和不同的算子间的对应关系可以表示为网格形式,当我们使用 m.impl 注册一个算子实现时就会传入相关的 函数指针 以及 diapatch key,假设这里我们用 C++ 实现了一个 cpu kernel 的 aten::mul 算子,对应到下图就是:

Pytorch 中也可以使用 ”catch-all“ 的操作来将同一个 kernel 注册到所有 dispatch key 上:

或者可以通过 fallback 为所有算子添加一个 dispatch key:

对于上面不同的注册形式,会有一个优先级的关系, 特定的实现 > catch all > fallback。

boxing 和 unboxing

更进一步,谈到函数调用,需要介绍另一个 dispatcher 中另一个重要的概念,boxing 和 unboxing,我们可以从数据结构的角度来对这个概念进行理解。在 C++ 中,我们都知道数据有不同的数据类型,int,float,double 等,包括一些类对象,这些多种多样的数据类型我们就可以理解为是一种 Unboxing 的行为,在 Pytorch 中定义了一种叫 IValue 的数据结构,它可以用来表示很多种数据类型,对外的表现就都是 IValue 这一种类型,这样将很多种元素打包,对外表现成一种就可以理解为是 boxing 的行为。那么对于 unboxing 的输入和 boxing 的输入,Pytorch 自然就要根据输入来进行不同形式的调用/转换(从 boxing 转换成 unboxing,或者从 unboxing 转换成 boxing),这部分将在下面源码部分做详细介绍。

类关系概述

在进入源码介绍之前首先概括性介绍一个将要出现的各个类的功能,和我自己整理的相互之间的关系(可能关系表示不是那么准确,欢迎拍砖交流),有一个大概的认识,然后有兴趣的读者可以进一步往下了解。

  1. IValue 类 : Pytorch 一种统一的数据表示,由此也引出了 boxing 和 unboxing 的概念。
  2. FuncitionSchema 类:算子表示类,通过 string 类型的 schema 定义,解析处算子的相关定义,包括输入个数,参数类型及返回值类型等信息。
  3. KernelFunction 类:相当于 std::function,主要是可以从函数指针、仿函数、lambda 函数来构建一个 boxed/unboxed 的 KernelFunction 对象,包括用 boxed/unboxed 不同方式来调用它们,这个类才是最后函数被调用的地方。
  4. OperatorHandle/ TypedOperatorHandle 类:提供 schema 相关的查询,包括算子 kernel 的调用接口。
  5. OpratorEntry 类:记录算子信息的类,包括 获取/注册 schema,注册/查找 kernelfunction,包括更新相关的 dispatch table 的相关接口。
  6. Dispatcher 类:组合上面的类功能,实现算子动态分发的类。

在这里插入图片描述
Dispatcher 类通过 OperatorEntry 类来进行算子的 schema 和 KernelFunction 的注册,其中就包括了更新 dispatch table 等行为。最后在 kernel 调用阶段需要依赖 KernelFunction 中的 call 函数进行最后的调用。

到这里,基本就是对 dispatcher 相关概念的一个介绍,如果仅仅只是想了解的同学,看到这里就可以了,碍于文章篇幅,涉及到源码的部分放在了另一篇文章,想要了解更详细的相关源码阅读,可以继续移步☞:

[Pytorch 源码阅读] —— 谈谈 dispatcher(二)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值