【Cute】流水线代码理解

【Cute】流水线代码理解

导读

https://zhuanlan.zhihu.com/p/665082713

无流水线下的矩阵乘法

  int num_tile_k = size<2>(gA);
#pragma unroll 1
  for(int itile = 0; itile < num_tile_k; ++itile) {
    cute::copy(tAgA(_, _, _, itile), tArA);
    cute::copy(tBgB(_, _, _, itile), tBrB);

    cute::gemm(tiled_mma, tCrC, tArA, tBrB, tCrC);
  }

上述代码是slice-k模式的乘法,gA的shape是(MMA_M, MMA_K,num_tile_k),gB的shape是(MMA_N,MMA_K,num_tile_k),最后输出的tCrC为在寄存器上的MMA_M,MMA_N,slice-k就是在warp上循环num_tile_k进行计算。

流水线

pipeline的概念在导读的参考中讲的很不错,建议先进行阅读。本文很多图片引用自它。

流水线的核心逻辑就是将非数据依赖的加载和计算环节解耦开来,并行进行,预取数据。

流水线主要有两部分:

  1. tile块间的pipeline,这个操作的思想就是在计算当前tile块时,提前预取下一个(或者多个,看stage数量,需要成倍的smem空间)tile块的数据。这样就不需要在计算完当前tile后才开始下一个tile的读取(这里就会产生串行的等待开销,pipeline主要就是为了缓解这个)。
    请添加图片描述
    请添加图片描述

  2. tile块内的pipeline,在上述代码中,直接cute::gemm就完成了整个tile块的计算,这是cutlass提前定义好的(不过暂且不清楚它的预定义是否有做优化)。tiledMMA定义时是有布局的(value layout或新的Permutations描述),即在M和N维度上进行了atom重复,这个重复不会增加线程,只是循环,tile块内的pipeline就是旨在做atom粒度的数据预取(只要一个操作涉及到读取与计算操作,并且其是loop逻辑的,那么它就有多级并行的潜在可能性)。

    请添加图片描述

执行逻辑

请添加图片描述
reference:https://github.com/reed-lau/cute-gemm/blob/main/gemm-multi-stage.cu#L102-L168

绿色块的含义是tile块的读取,三块即有3个stage,smem大小是one stage的三倍。如果把剩下的看成一个整体,那么它们两个构成tile块间的pipeline;这一步是global mem到shared mem的拷贝

黄色块的含义就是tile内的小k块(一个atom指令大小)数据读取,深绿块的含义就是对寄存器内的数据进行mma的计算,也是在atom级别;这一步是完成shared mem到register mem的拷贝,并对register内的数据进行计算。

主要逻辑就是判断数据读取任务何时发射,发射哪一块,什么时候产生数据依赖,需要做等待同步。

以下就针对代码做逐行注释;

 	int itile_to_read = 0; // 当前在读第几个大tile块,一共就是上面提到的num_tile_k块
  int ismem_read = 0; // 当前在读的第几个大tile块的smem,这个是共享内存的,需要往register中写
  int ismem_write = 0; // 这个对应itile_to_read,从global mem写到smem的定位

  // 这一段是先发射kStage - 1个大tile块的数据读取任务,只留下一个stage空位用以后面的数据循环
#pragma unroll
  for (int istage = 0; istage < kStage - 1; ++istage) {
    cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, istage),
               tAsA_copy(_, _, _, istage));   // gmem -> shm
    cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, istage),
               tBsB_copy(_, _, _, istage));  // gmem -> shm
    cp_async_fence();

    ++itile_to_read;
    ++ismem_write;
  }

  // wait one submitted gmem->smem done
  cp_async_wait<kStage - 2>(); // 这个函数的含义是最多允许还有kStage-2个任务没有完成
  // 之前发射了kStage-1个任务,那么这里的含义就是最多还允许有1个任务未完成
  __syncthreads();

  // 先把第一个小k传到register,方便后面的循环,这步的逻辑就是对应上面图中的第一个红色块
  int ik = 0;
  // smem -> reg
  cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik, ismem_read), tCrA_view(_, _, ik));
  cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik, ismem_read), tCrB_view(_, _, ik));

  // loop over k: i. load tile, ii. mma
  int ntile = k / kTileK; // 这个和size<2>(tAgA_copy)是等价的
#pragma unroll 1
  for (int itile = 0; itile < ntile; ++itile) { // 进入块间流水线循环
    int nk = size<2>(tCrA); // 这个和num_tile_k是类似的逻辑,只是这边是针对atom粒度的,num_tile_k是针对tiledMMA粒度的

#pragma unroll
    for (int ik = 0; ik < nk; ++ik) { // 进入块内流水线循环,这边看代码也是串行的,但是是在register内部,可能没有数据依赖会自动并行掉
      int ik_next = (ik + 1) % nk; // ik_next是用来做数据预取的

      if (ik == nk - 1) { // ik是用来做数据计算,如果已经是当前tile的最后一块,则要提前判断让下一个大tile块数据从global mem加载到smem完成
        cp_async_wait<kStage - 2>();
        __syncthreads();

        ismem_read = (ismem_read + 1) % kStage; // 这边是一个取模,当前所在计算的数据在stage总空间中是循环轮转的
      }

      // shm -> reg s[itile][ik + 1] -> r[ik + 1]
	  // 此处就是提前读取下一个小k块所需的数据,ismem_read就是当前读取所在的smem的stage位置,是轮转的
      cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read),
                 tCrA_view(_, _, ik_next));
      cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read),
                 tCrB_view(_, _, ik_next));

      if (ik == 0) { // 当前计算的块是第一块时,需要去发射下一个大tile的数据预取任务,下一块和当前块是没有数据依赖的。
        if (itile_to_read < ntile) { // 判断是否已经是最后一块,不是的话,拷贝下一个大tile的数据从gmem到smem
          cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read),
                     tAsA_copy(_, _, _, ismem_write));
          cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read),
                     tBsB_copy(_, _, _, ismem_write));

          ++itile_to_read; // tile块id加1
          ismem_write = (ismem_write + 1) % kStage; // ismem_write是取模stage,它是轮转的
        }

        cp_async_fence();
      }

      cute::gemm(tiled_mma, tCrD, tCrA(_, _, ik), tCrB(_, _, ik), tCrD);
    }  // for ik
  }    // itile

结合图示看完注释肯定就能理解了,因为参考的代码是可以直接run的,建议手动跑一下;

官方也有一份代码示例:https://github.com/NVIDIA/cutlass/blob/ffa34e70756b0bc744e1dfcc115b5a991a68f132/include/cutlass/gemm/collective/sm80_mma_multistage.hpp,但没找到单元测试的入口,可以看一下代码

  • 9
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值