【Cute】流水线代码理解
导读
https://zhuanlan.zhihu.com/p/665082713
无流水线下的矩阵乘法
int num_tile_k = size<2>(gA);
#pragma unroll 1
for(int itile = 0; itile < num_tile_k; ++itile) {
cute::copy(tAgA(_, _, _, itile), tArA);
cute::copy(tBgB(_, _, _, itile), tBrB);
cute::gemm(tiled_mma, tCrC, tArA, tBrB, tCrC);
}
上述代码是slice-k模式的乘法,gA的shape是(MMA_M, MMA_K,num_tile_k),gB的shape是(MMA_N,MMA_K,num_tile_k),最后输出的tCrC为在寄存器上的MMA_M,MMA_N,slice-k就是在warp上循环num_tile_k进行计算。
流水线
pipeline的概念在导读的参考中讲的很不错,建议先进行阅读。本文很多图片引用自它。
流水线的核心逻辑就是将非数据依赖的加载和计算环节解耦开来,并行进行,预取数据。
流水线主要有两部分:
-
tile块间的pipeline,这个操作的思想就是在计算当前tile块时,提前预取下一个(或者多个,看stage数量,需要成倍的smem空间)tile块的数据。这样就不需要在计算完当前tile后才开始下一个tile的读取(这里就会产生串行的等待开销,pipeline主要就是为了缓解这个)。
-
tile块内的pipeline,在上述代码中,直接cute::gemm就完成了整个tile块的计算,这是cutlass提前定义好的(不过暂且不清楚它的预定义是否有做优化)。tiledMMA定义时是有布局的(value layout或新的Permutations描述),即在M和N维度上进行了atom重复,这个重复不会增加线程,只是循环,tile块内的pipeline就是旨在做atom粒度的数据预取(只要一个操作涉及到读取与计算操作,并且其是loop逻辑的,那么它就有多级并行的潜在可能性)。
执行逻辑
reference:https://github.com/reed-lau/cute-gemm/blob/main/gemm-multi-stage.cu#L102-L168
绿色块的含义是tile块的读取,三块即有3个stage,smem大小是one stage的三倍。如果把剩下的看成一个整体,那么它们两个构成tile块间的pipeline;这一步是global mem到shared mem的拷贝
黄色块的含义就是tile内的小k块(一个atom指令大小)数据读取,深绿块的含义就是对寄存器内的数据进行mma的计算,也是在atom级别;这一步是完成shared mem到register mem的拷贝,并对register内的数据进行计算。
主要逻辑就是判断数据读取任务何时发射,发射哪一块,什么时候产生数据依赖,需要做等待同步。
以下就针对代码做逐行注释;
int itile_to_read = 0; // 当前在读第几个大tile块,一共就是上面提到的num_tile_k块
int ismem_read = 0; // 当前在读的第几个大tile块的smem,这个是共享内存的,需要往register中写
int ismem_write = 0; // 这个对应itile_to_read,从global mem写到smem的定位
// 这一段是先发射kStage - 1个大tile块的数据读取任务,只留下一个stage空位用以后面的数据循环
#pragma unroll
for (int istage = 0; istage < kStage - 1; ++istage) {
cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, istage),
tAsA_copy(_, _, _, istage)); // gmem -> shm
cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, istage),
tBsB_copy(_, _, _, istage)); // gmem -> shm
cp_async_fence();
++itile_to_read;
++ismem_write;
}
// wait one submitted gmem->smem done
cp_async_wait<kStage - 2>(); // 这个函数的含义是最多允许还有kStage-2个任务没有完成
// 之前发射了kStage-1个任务,那么这里的含义就是最多还允许有1个任务未完成
__syncthreads();
// 先把第一个小k传到register,方便后面的循环,这步的逻辑就是对应上面图中的第一个红色块
int ik = 0;
// smem -> reg
cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik, ismem_read), tCrA_view(_, _, ik));
cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik, ismem_read), tCrB_view(_, _, ik));
// loop over k: i. load tile, ii. mma
int ntile = k / kTileK; // 这个和size<2>(tAgA_copy)是等价的
#pragma unroll 1
for (int itile = 0; itile < ntile; ++itile) { // 进入块间流水线循环
int nk = size<2>(tCrA); // 这个和num_tile_k是类似的逻辑,只是这边是针对atom粒度的,num_tile_k是针对tiledMMA粒度的
#pragma unroll
for (int ik = 0; ik < nk; ++ik) { // 进入块内流水线循环,这边看代码也是串行的,但是是在register内部,可能没有数据依赖会自动并行掉
int ik_next = (ik + 1) % nk; // ik_next是用来做数据预取的
if (ik == nk - 1) { // ik是用来做数据计算,如果已经是当前tile的最后一块,则要提前判断让下一个大tile块数据从global mem加载到smem完成
cp_async_wait<kStage - 2>();
__syncthreads();
ismem_read = (ismem_read + 1) % kStage; // 这边是一个取模,当前所在计算的数据在stage总空间中是循环轮转的
}
// shm -> reg s[itile][ik + 1] -> r[ik + 1]
// 此处就是提前读取下一个小k块所需的数据,ismem_read就是当前读取所在的smem的stage位置,是轮转的
cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read),
tCrA_view(_, _, ik_next));
cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read),
tCrB_view(_, _, ik_next));
if (ik == 0) { // 当前计算的块是第一块时,需要去发射下一个大tile的数据预取任务,下一块和当前块是没有数据依赖的。
if (itile_to_read < ntile) { // 判断是否已经是最后一块,不是的话,拷贝下一个大tile的数据从gmem到smem
cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read),
tAsA_copy(_, _, _, ismem_write));
cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read),
tBsB_copy(_, _, _, ismem_write));
++itile_to_read; // tile块id加1
ismem_write = (ismem_write + 1) % kStage; // ismem_write是取模stage,它是轮转的
}
cp_async_fence();
}
cute::gemm(tiled_mma, tCrD, tCrA(_, _, ik), tCrB(_, _, ik), tCrD);
} // for ik
} // itile
结合图示看完注释肯定就能理解了,因为参考的代码是可以直接run的,建议手动跑一下;
官方也有一份代码示例:https://github.com/NVIDIA/cutlass/blob/ffa34e70756b0bc744e1dfcc115b5a991a68f132/include/cutlass/gemm/collective/sm80_mma_multistage.hpp,但没找到单元测试的入口,可以看一下代码