24_gemm_grouped

思想

最朴素的思想是一个gemm起一个kernel,然后循环就可以n次就行,但是有些gemm尺寸比较小,这样一个wave下sm无法全部利用,为此用多stream来启动,这样就可以保证多个kernel可以一块执行,但是效果不是很好,可能是因为多次启动的原因。为此需要继续优化思路,这里利用 persistent kernel的思想去处理,简单的来说就是kernel不切出去,一直循环加载待处理问题。
在这里插入图片描述
上面这幅图简单的介绍了思想,这里要注意,绿色块的gemm是指结果矩阵,假如每个sm里面可以放得下2个block, 那就是6个常驻block,这里假如有3个gemm,那么每个block需要负责的区域如上图。这有个隐藏的问题是某个gemm的k极大,那么负责那个gemm的几个block就会一直在那计算,但是其他block都已经完成工作了,所以就需要提前进行排序,把耗时多的tile尽量都放在一个wave里面完成。

数据初始化

int64_t total_elements_A = 0;
std::vector<int64_t> lda_host;//存每个problem的数据个数,存在host里,后面拷贝到device的lda
lda_host.resize(problem_count());
    for (int32_t i = 0; i < problem_count(); ++i) {
      auto problem = options.problem_sizes.at(i);
      lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0);
      offset_A.push_back(total_elements_A);
      int64_t elements_A = problem.m() * problem.k();
      total_elements_A += elements_A;
	  lda.reset(problem_count());
	  block_A.reset(total_elements_A);// device 一共需要多少A数据量,同时cudamalloc
    }

综上在device上:
lda=[size_0, size_1, …, size_14 ];
problem_sizes_device=[[m0,n0,k0], [m1, n1, k1],…[m14,n14,k14]]
block_A=[all value]
ptr_A=[ptr_0, ptr_1,…ptr_14]

  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;// C和D都是half_t, 但是A*B的中间值用float
  using ElementAccumulator = float;

  using LayoutA = cutlass::layout::ColumnMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementA,
    LayoutA,
    cutlass::ComplexTransform::kNone,
    8,//因为是half,所以单条指令访问128bit/16bit=8
    ElementB,
    LayoutB,
    cutlass::ComplexTransform::kNone,
    8,
    ElementOutput, LayoutC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
    // NOTE: Threadblock 
    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;

  using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

计算

这个就是和之前的gemm没啥区别了,重要的是problem_visitor的实现,一种方式在cpu上算好放在share memory中,一种是在gpu上运行,我比较推荐在cpu上,逻辑简单清楚一些,做计算的时候只要关注gemm就行。

总结

这里简单记录一下group的思想,在hopper架构上,为了隐藏prelogue 和epilogue的延迟也是这个思想写的,这里可以先热身了解一下

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值