GEMM inTriton (Split-K and Stream-K)

Triton是OpenAI的开源项目。官网https://openai.com/index/triton/。Github地址https://github.com/triton-lang/triton。自问世来,一直以来都受到业界关注,而且近年来热度似乎有了明显提升。可以看到将Triton用于LLM的例子越来越多。各种流行的LLM框架,如vLLM,SGLang和TRT-LLM中也都有了Triton的身影。PyTorch也对它进行了官方支持。在PyTorch中Triton可用于自定义算子的开发并方便地与torch.compile集成(https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)

它主要解决的是在并行加速芯片上写高性能算子的问题。像CUDA这样的编程接口易学难精。写个能工作的实现和写出具有SOTA性能的kernel所需的专业时间和工程精力差别巨大。Triton的定位就是以极小的工程代价能达到手写算子约八成的性能。正如同Python与C++,或者C++与汇编的关系。将来可能大多算子会用Triton开发,只有在那些性能瓶颈的算子才会用CUDA去开发。

对于大多数使用者而言,更关心的是如何使用Triton。一个官方hello world可见:https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html。从中可以大概看到用Triton写一个kernel的基本范式与套路。Triton的DSL中有program_id的概念,对应CUDA中的CTA,也就是thread block。使用中很多时候以block为单位,这样就可以尽量少地纠缠于warp, thread等更细节的概念。

本文以最常见的计算GEMM为例,看下用Triton是如何实现它以及它的几种变体(Split-K,Stream-K)的。

Classic GEMM

作为引子,首先看下经典的用tiling来做GEMM是如何在Triton中实现的。基本的写法可参见:https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html。Kernel的CTA的个数就是输出矩阵中的tile数量。也就是说,每个CTA计算一个输出矩阵中的tile。它需要循环多次进行累加。循环次数为k维上的block数。

与CUDA有所不同的是,像这里的offs_amoffs_bn等描述的都是一个range,即下标数组。如果BLOCK_SIZE_M不能整除M的话,余数部分会从0开始。这一部分是冗余计算,最后会用mask过滤掉。整体代码比较易懂,不需要过多解释。这里稍微绕一些的可能用于L2 optimization的CTA id重映射。

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

经过这个变换(M维度上的分块)后,按CTA id递增tile的顺序变成:
在这里插入图片描述

Split-K GEMM

源码可参考https://github.com/triton-lang/triton/blob/v2.1.0/python/triton/ops/matmul.py。首先是kernel启动部分:

# launch kernel
grid = lambda META: ( 
    cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
    META["SPLIT_K"],
)
_kernel[grid](

这里grid为二维,第一个维度为output矩阵的tile数,即M维的block数 x N维的block数,第二个维度为Split-K中的K,即K维上分成几个partition。这些partition由不同的CTA计算并累加。

接下来就是kernel的定义:

@autotune(                                                                
    configs=[                                                             
        # basic configs for compute-bound matmuls                         
        Config(                                                           
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=3,                                                 
            num_warps=8,                                                  
        ),                                                                
        ...
}
@heuristics(                                                                         
    {                                                                                
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, 
    }                                                                                
)                                                                                    
@jit                                                                                 
def _kernel(...)                                                                       

这里由@jit修饰的就是kernel函数的定义了。Kernel函数的参数分成几类:

  • 调用者给的,如输入输出,维度信息这些。
  • 可tuning参数,如block大小。它们与性能相关。可参见:https://triton-lang.org/main/python-api/generated/triton.autotune.html。
  • 基于heuristics得到的参数,基于预定义规则得到。如EVEN_K代表K维上的元素能否被CTA平分。可参见:https://triton-lang.org/main/python-api/generated/triton.heuristics.html。

上面这个kernel实现可分为三个部分:

  1. 根据program_id确定当前要处理的数据。
# matrix multiplication                                           
pid = tl.program_id(0)                                            
pid_z = tl.program_id(1)                                          
grid_m = tl.cdiv(M, BLOCK_M)                                      
grid_n = tl.cdiv(N, BLOCK_N)                                      
# re-order program ID for better L2 performance                   
width = GROUP_M * grid_n                                          
group_id = pid // width                                           
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)            
pid_m = group_id * GROUP_M + (pid % group_size)                   
pid_n = (pid % width) // (group_size)                             
# do matrix multiplication                                        
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)                      
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)                      
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)                      
# pointers                                                        
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)      
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)      

其中以下几行:

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

告诉编译器数据是align且连续的。这样编译器就可以做一些诸如vectorization的优化。

  1. 对于每个输出中的tile,在K维上做进行累加。
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)                                
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):                                  
    if EVEN_K:                                                                     
        a = tl.load(A)                                                             
        b = tl.load(B)                                                             
    else:                                                                          
        k_remaining = K - k * (BLOCK_K * SPLIT_K)                                  
        _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)                            
        a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)                   
        b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)                   
    if AB_DTYPE is not None:                                                       
        a = a.to(AB_DTYPE)                                                         
        b = b.to(AB_DTYPE)                                                         
    if fp8_fast_accum:                                                             
        acc = tl.dot(                                                              
            a, b, acc, out_dtype=acc_dtype, input_precision=input_precision        
        )                                                                          
    else:                                                                          
        acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision)  
    A += BLOCK_K * SPLIT_K * stride_ak                                             
    B += BLOCK_K * SPLIT_K * stride_bk                                             
acc = acc.to(C.dtype.element_ty)                                                   
  1. 写回结果。
# rematerialize rm and rn to save registers                  
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)                 
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)                 
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)  
mask = (rm < M)[:, None] & (rn < N)[None, :]                 
# handles write-back with reduction-splitting                
if SPLIT_K == 1:                                             
    tl.store(C, acc, mask=mask)                              
else:                                                        
    tl.atomic_add(C, acc, mask=mask)                         

这里首先重新计算了输出矩阵中的坐标。其实前面已经算过了,但如果保留到这里就需要耗费register。

然后用mask保证不越界。最后判断如果SPLIT_K为1,即退化为非SPLIT K的情况,那就直接存结果。否则意味着多个CTA共同计算一个output矩阵的tile,那就需要做累加。又由于这些写的线程属于不同CTA,可能并行执行,因此需要使用atomic add。

Stream-K GEMM

传统的做法是以problem为出发点,将problem size进行切割后,再将它们分到并行的计算单元上。但GPU上的并行计算处理器(SM)数量是固定的。这可能导致wave quantization问题(即子问题的数量不是并行处理器的整数),浪费计算资源。举个最简单的例子,一个任务可分为11个小任务(每个小工作需要1人/天),分给10个人干。但共需要2天才能全部完成,第二个会有9个人是无事可干的。而且这种问题随着GPU的更新,会越来越严重。因为计算单元更强,意味着需要更大的子问题才能“喂饱”它。那子问题的数量自然就会更少,这样就更容易出现wave quantization的现象。

那我们是否可以将子问题切得足够小(如果M, N维不够切就用Split-K),这样就能减少或避免wave quantization。但这样虽然SM利用率可能高了,但性能未必高。Block size过小可能会导致线程内IPL机会变小,另外计算密度变小。如果这样比较抽象的话,可以看一下具体的例子:https://pytorch.org/blog/accelerating-triton/#50-warp-stalling。

Split-K的限制在于它需要K维上的block在CTA间均分。那我们可不可以换个角度,从并行计算处理器为出发点,将子问题按SM数量来切分?这样所有的任务都可以在一个wave中完成,自然也不存在wave quantization的问题。这就是Stream-K的思想。还是用上面的例子,一个任务分为11个小的子任务,将11个子任务按人数分成10份,每个人完成1.1份。Stream-K与Split-K相比,SM利用率更高。而且同步归约次数也更少(不多于SM个数)。

Stream-K算法可参见2023年的论文《Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU》的Algorithm 5。但单纯使用Stream-K可能导致tile-processing skew问题。因此实际使用时会采用称为"two-tile Stream-K + data parallel"的混合调度方法。它的基本思想是先用Stream-K将那些除不尽的可能导致wave quantization的部分在一个wave中处理掉。这样剩下的部分是能整除的,就可以用传统的高效的DP(data parallel)方法来计算了。便于理解的形象的图可参见论文中的Figure 3 (c)

接下来学习如何用Triton来实现Stream-K:https://github.com/triton-lang/triton/issues/1393。代码大体可分三步:首先根据problem计算workload在Stream-K和DP之间怎么划分。然后分别起两个kernel用Stream-K和DP来计算相应的部分。

  1. 准备阶段
M, K = a.shape
_, N = b.shape
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# compute grid (work to do per SM on the first wave)
total_blocks_M = triton.cdiv(M, BLK_M)
total_blocks_N = triton.cdiv(N, BLK_N)
iters_per_tile = triton.cdiv(K, BLK_K)
GROUP_M = 8  # 0 to disable swizzling
total_tiles = total_blocks_M * total_blocks_N

if total_programs_streamk > 0:  # Stream-K
    # last wave may occupy less than total_programs_streamk SMs
    total_tiles_streamk = total_tiles % total_programs_streamk
    # for two-tile Stream-K + data-parallel from original paper
    if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:
        total_tiles_streamk += total_programs_streamk
    # remaining tiles are computed using classical blocking
    total_blocking_tiles = total_tiles - total_tiles_streamk
    total_iters_streamk = total_tiles_streamk * iters_per_tile
    # iterations related to full waves
    total_full_tiles_streamk = total_iters_streamk // total_programs_streamk
    # iterations related to last (partial) wave
    total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk

...

这里total_programs_streams是SM数量(同时也是CTA数量)。total_blocks_Mtotal_blocks_Niters_per_tile分别是M, N, K三个维度上的block数。total_tiles是输出矩阵上的tile数量。将这个tile数量除以SM数量,余数部分交给stream k处理。这个数量为total_tiles_streamk。另外根据paper中的two tiles策略解决SM间的workload imbalance问题,对于这部分每个SM还要再多算一个tile。如共21个tile,4个SM。余数为1,每个SM再多算一个tile,则stream k中,是4个SM计算5个tile。那剩下的tile(数量为total_blocking_tiles)是由传统DP方法来算。total_full_tiles_streamktotal_partial_tiles_streamk分别是Stream-K中每个SM处理的iter次数的最小值,以及余下的iter次数(意味着有些SM需要平摊这些余下的部分)。下面是一个小规模(5个tile,每个tile在K维上2个block)的切分示意图:
在这里插入图片描述
可以看到,除不尽的余数部分被分到了前几个SM。

  1. Stream-K

对应的kernel函数为first_wave。这里的grid size为SM数量。也就是说每个CTA对应一个SM。SM是可以并行执行的。

# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocates locks to sync work accross SMs
locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
k1 = first_wave[(total_programs_streamk,)](...)

下面看下该kernel的实现:

@triton.jit()                                                                                    
def first_wave(                                                                                  
        A, B, C,                                                                                 
        M, N, K,                                                                                 
        locks,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,                        
        total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile,                   
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):      
    pid = tl.program_id(0)                                                                       
    start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk)   
    last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk)
    
    while start_iter < last_iter:
        end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter)
        mac_loop(A, B, C,
                 M, N, K,                                                                        
                 locks,
                 stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                 iters_per_tile,                                                                 
                 start_iter, end_iter,                                                           
                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,                                            
                 GROUP_M,                                                                        
                 )
                 
        start_iter = end_iter

它对应论文的Algorithm 5。这里有点绕的地方在于一个CTA处理可能跨了两个tile,也就是说包含了两个tile的iteration。先以简单的场景(2个CTA处理3个tile,每个tile在K维上分4个block)过一下论文中的算法。
在这里插入图片描述

回到代码来。代码中的start_iterlast_iter对应论文中iteriter_end,代表当前CTA要处理的iter范围。这里在计算当前CTA的iter范围(即start_iterlast_iter)时还要考虑无法整除的情况。

由于这些iter可能跨tile,因此要调多次mac loop完成。代码中的start_iterend_iter对应论文中的 local_iterlocal_iter_end,代表当前这次mac loop所处理的iter范围。Mac loop的主体部分就是对于给定范围在k维上做reduction。代码实现中将数据的写回部分也放到mac_loop函数中了,所以与论文中Algorithm 3中的MacLoop相比,看起来要复杂一些。

结果写回部分的复杂之处在于一个tile中的iter现在可能由不同的CTA完成了。因此得由一个CTA写,其它的做原子累加。假设某个tile中的iter会由2个CTA(标为CTA0和CTA1)共同完成。

对应地,代码中,CTA 1的mac_loop中由于条件end_iter % iters_per_tile == 0成立,会将当前的partial accumulator写入到结果矩阵C中。然后通过tl.atomic_xchg通知其它CTA。CTA 0在mac_loop中会先等待CTA 1完成的信号,然后通过tl.atomic_add将partial accumulator以原子操作的方式加到结果矩阵。示意图如下:
在这里插入图片描述

  1. DP
    这部分起第二个kernel,采用经典的分块策略计算剩余部分。该策略前面已经讲过了,这里不再累述。
k2 = full_tiles[(total_blocking_tiles,)](...)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值