【OpenAI Triton】理解矩阵乘法中的super-grouping

【OpenAI Triton】理解矩阵乘法中的super-grouping

前言

最近做推理加速,会涉及一些底层算子的工作,老早就听说triton写算子比较方便,最近正好有一些应用场景,就根据官方文档和大佬们的见解记录一下自己的所学所得;

参考

本文主要是记录自己在理解学习时对其中一块内容的理解,并不是做复述或翻译一遍官方文档的内容。所以阅读本文前建议先根据官方文档自己跑一遍矩阵乘法的示例,对triton的功能有个大致的理解,然后再来过其中每一行的代码;如果你对cuda等比较熟悉,看完之后可能就直接秒懂,哈哈哈

L2 Cache Optimizations

原始的实现

pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

l2 cache 优化后的实现

# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m

首先讨论为何需要进行L2 Cache优化。简单来说,GPU硬件中存在寄存器、L1 Cache、L2 Cache和全局内存等结构,它们的读写效率逐级降低。

寄存器是GPU中最快速的存储器,用于存储线程的变量和计算中间结果。每个线程都有自己的一组寄存器,能够进行快速访问。然而,寄存器的数量非常有限,通常只有几十到几百个。对于计算密集型任务,如矩阵乘法,可以利用寄存器来存储临时变量和迭代计算中的中间结果,以减少对其他内存层次的访问。

L1 Cache位于GPU SM(Streaming Multiprocessor)内部,用于存储频繁访问的数据和指令。它是一个相对较小但速度较快的缓存,用于提高数据的局部性和访问效率。L1 Cache主要用于存储线程级别的数据,如线程的寄存器溢出数据、局部变量以及线程块内共享内存的数据。

L2 Cache是位于GPU SM之上的一个更大的缓存层次。它的容量通常比L1 Cache大数倍,但速度相对较慢。L2 Cache用于存储来自多个SM的数据,并提供更大的缓存容量以提高数据的局部性和复用性。L2 Cache能够减少对全局内存的访问,从而提高数据访问效率和整体性能。

回到矩阵乘法的优化,由于它是计算密集型操作,数据传输损耗对性能影响非常严重。因此,能够利用最近的数据存储器是至关重要的。

通常情况下,矩阵乘法会按照一个矩阵块的大小进行计算。在每次计算之前,所需的数据会从全局内存加载到L2 Cache中,然后在SM执行过程中直接从L2 Cache读取和写入数据。命中率指的是计算所需的数据能否直接从L2 Cache获取,高命中率意味着可以减少对全局内存的数据获取,从而避免大量的数据传输性能损耗。

Triton与cutlass或cuda编程的区别

以我目前的浅薄理解,Triton的编程模型主要集中在块(block)级别上,即用户无需过多关注块内部的线程计算过程。而Cutlass或CUDA编程往往更注重于细粒度的线程级别编程。因此,Triton在抽象层面上更高级,可以提高开发效率,但在性能和资源控制方面可能稍显不足。

理解Row-major ordering

pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

请添加图片描述

结合这段代码和这幅图,我们来分析row-major ordering的block循环逻辑。

在图中,可以看到矩阵A、B、C都是9x9的大小,但是要注意每个黄色格子代表一个block。如果我们设定一个BLOCK_SIZE_M x BLOCK_SIZE_N大小为64x64,那么矩阵A和B的大小都将是576x576。这也是之前所说的triton是基于block逻辑进行编程的。

在运行时,一个SM可能会同时计算多个block,而多个SM则可以并行计算更多的block。但是无论是哪个SM计算,它所需的矩阵数据都会优先从L2 Cache中获取。这与之前解释的L2缓存命中率密切相关。

pid = triton.program_id(0);

这里的program_id是一个非常重要的概念。我们编写的程序只确定了一个block的计算过程,而所有block的计算是由编译器来编译循环。这行代码实际上是在确定这个block在循环逻辑中的位置。其中的axis=0表示这个“循环”是一维的,即只有一层。如果还有axis=1,那就意味着还有嵌套的第二层。这些不同的block是并行执行的(不同的物理硬件)。

grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;

这两行比较好理解,就是计算出在block维度,行和列block的数量;

pid_m = pid / grid_n;
pid_n = pid % grid_n;

这两行代码是row-major ordering的核心逻辑,也是最简单的逻辑。在triton编程中,除了确定每个block内部的计算逻辑外,还可以根据pid(program_id)确定block的遍历逻辑,这是一个非常关键的概念。

根据之前的说明,这里的pid只有一维,范围是从0到80。在这个9x9的矩阵中,我们需要确定如何将0到80的序号填入其中,这就是所谓的block ordering逻辑。在这个例子中,我们按行遍历矩阵来确定pid → (pid_m, pid_n)的值,所以被称为row-major ordering(按行优先顺序)。

row-major ordering下的读写

这个官方解释得很清楚,我们以计算9个block为例来说明。在row-major ordering的模式下,对于矩阵A来说,需要读取9个block的数据;而对于矩阵B来说,需要读取81个block的数据;最后,矩阵C需要写入9个block的数据。因此,总共需要读取90个block的数据,写入9个block的数据。

Super-Grouping Ordering

请添加图片描述

看官方给的图,先说结论,同样在写入9个block的数据时,矩阵A和矩阵B都需要读取27个block的数据,总共涉及54个block的读取操作。相比于row-major ordering,这是一个显著的改进。

通常情况下,较高的L2缓存命中率通常意味着较少的读写次数,而较低的L2缓存命中率则通常伴随着更多的读写次数。

由于L2缓存是有限的,想象一下进行一次密集计算操作时,同时有大量的SM并行运行。如果存在大量的读写操作,无疑会对L2缓存的数据存储产生影响。当矩阵的规模很小,只需要一个指令就能完成所有数据的计算时,即所有的数据都能放到L2缓存中,L2缓存的影响就不明显。然而,在实际情况下,这种情况是不太可能的。

排布逻辑

如果我们能够完全理解row-major ordering的排布过程,那么其他的排布逻辑其实也就很容易理解了。这是因为它们的原理是相同的,都是通过pid(program_id)来确定(pid_m, pid_n)的值,即在一个9x9的block矩阵中按照希望的顺序填入pid序号。

例如,对于super-grouping的结构,它实际上是将一个block按照横向和纵向同时进行拓展,形成一个小矩形。这个小矩形看起来就像一个超级小组

在实际编程中,我们可以根据具体的需求和算法的特性,选择不同的排布逻辑来组织block的布局。无论是row-major ordering、column-major ordering还是super-grouping,它们的核心思想都是通过pid来确定每个block在整个block矩阵中的位置和顺序。

理解这些排布逻辑有助于我们更好地设计并行计算任务的数据布局,从而利用好计算资源,提高计算效率和性能。

接下来按行阐述其排布过程

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

引用董鑫大佬的两幅图(参考的第三个链接)
请添加图片描述
请添加图片描述
前三行代码逻辑是一致的,不再赘述;

num_pid_in_group = GROUP_SIZE_M * num_pid_n GROUP_SIZE_M是行方向的组大小,这里定义为3,即上面第一幅图的红色框框,num_pid_in_group 就是计算该组内一共有多少个block;

group_id = pid // num_pid_in_group 就是判断对于当前pid它是在哪个group;

first_pid_m = group_id * GROUP_SIZE_M 计算当前group第一个pid_m的编号,注意是pid_m,上面提到,排布逻辑其实就是将pid映射到(pid_m, pid_n)的过程;

group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 这一步是为了避免最后一个group是无法整除的,当前这个例子正好是整除的,所以看不太出来。稍微阐述一下,假如无法整除,设最后一个group只有2行,因为是按列排序,在算pid在这个group中对应的pid_m时,假如pid是30,那么其行号就应该是(30-27)%2=1;结合图2可以对比一下。

pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

这两行就是将pid映射到(pid_m,pid_n)的最终逻辑代码了;

一个例子

接上图2,我们对pid=30的block,来计算一下其对应的实际pid_m和pid_n。

pid = 30
num_pid_m = 9
num_pid_n = 9
GROUP_SIZE_M = 3
num_pid_in_group = 3 * 9 = 27 # 一组有27个pid
group_id = 30 // 27 = 1 # 在第1组
first_pid_m = 1 * 3 = 3 # 第一组第一个pid的行号为3
group_size_m = min(9 - 3, 3) = 3 # 不是最后一组也不是非整除,所以不影响

pid_m = 3 + (30 % 3) = 3 + 0 = 3 # 按列排序,所以取模group_size_m
pid_n = (30 % 27) // 3 = 3 // 3 = 1

pid -> (pid_m, pid_n) <==> 30 -> (3, 1) # 根据图2对比一下

至此讲完block的逻辑排布;后面可能还会再补充一些东西

  • 19
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值