C++模板推导再炫技:统一深度学习框架各个device各个kernel的调用和分发

本文探讨了如何通过C++模板技术实现深度学习框架中对不同设备(如CPU、GPU)的GEMM操作的自动化分发,使用模板特化和Blas类封装底层库(如Cublas),以提供统一的接口供上层调用,从而简化设备间的代码管理。
摘要由CSDN通过智能技术生成

最近迷恋上了模板,看了很多模板推导,模板偏特化,模板特化,变长参数模板的例子,之前也发过一些模板的文章,比如这篇paddle的,本文会比paddle的这篇略简单一点,另外,准备后面几期文章再带来几个C++模板与深度学习应用结合的case。原文位于我的公众号“AI不止算法”文章链接在此

背景和问题

cpu和gpu的gemm实现肯定是不同的,以调库为例,一个是可以调cblas,一个是可以调cublas甚至cutlass,因为cblas和cublas的参数比较类似,所以本文选择cublas作为例子来讲。

那么问题来了,成熟的深度学习框架或者深度学习推理引擎都支持好几个后端device,比如IntelCPU armCPU AMDCPU NVGPU AMDGPU NPU XPU FPGA…etc, 那么我该如何实现一套优雅的自动化机制或接口,使得上层指定一个类型为某某device,就可以自动分发到指定device的不同gemm实现,而不是hard code直接调用某个device的gemm实现,如果把不同device的gemm实现比作人类的各个部位,那么上层指定device的输入就可比作大脑的指令,这套自动化机制和接口就是血液,优雅的流到不同的部位。

方法

废话不多说,直接show code。

我的最上层接口要调用matmul:

template <typename T, typename Device>
void MatmulKernel(const Device& dev,
        ...) {
  //初始化名为blas的这个handle,本质就是初始化了一个叫做Blas的类
  auto blas = GetBlas<Device, T>(dev);
  // handle调用matmul,完事
  blas.MatMul(args...);
}

这就是最上层的调用接口,只需要指定计算类型和Device,那就会自动分发到底层若干device的若干类型的gemm实现

接下来一步步揭开第5行和第7行的神秘面纱

第5行:不出所料,其实就是个构造函数的调用

template <typename Device, typename T>
inline Blas<Device, T> GetBlas(const Device& dev) {
  return Blas<Device, T>(dev);
}

第7行:Blas这个类包含了如下GEMM和Matmul等等一系列(重载)成员函数,用于handle不同输入参数的矩阵乘法需求。可以看出,Blas这个类其实就是一个wrapper类,负责承上启下的模板推导,推导到不同device的gemm实现

template <typename Device, typename T>
class Blas {
 public:
  explicit Blas(const Device& dev) : dev_(dev) {}

  void GEMM(bool transA,
            bool transB,
            int M,
            int N,
            int K,
            T alpha,
            const T* A,
            int lda,
            const T* B,
            int ldb,
            T beta,
            T* C,
            int ldc) const;

  void MatMul(Tensor& mat_a,
              bool trans_a,
              Tensor& mat_b,
              bool trans_b,
              T alpha,
              Tensor* mat_out,
              T beta) const;
   ......
}

那么我们接下来就可以上不同device的gemm具体实现,本质上就是模板偏特化和模板特例化,这是模板世界的IfThenElse,此处只给了CPU和GPU的特化,如果还有其它设备那么接着针对此设备特化实现即可

template <>
template <typename T>
void Blas<CPU, T>::GEMM(CBLAS_TRANSPOSE transA,
                                 CBLAS_TRANSPOSE transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 const T *B,
                                 T beta,
                                 T *C) const {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  CBlas<T>::GEMM(CblasRowMajor,
                 transA,
                 transB,
                 M,
                 N,
                 K,
                 alpha,
                 A,
                 lda,
                 B,
                 ldb,
                 beta,
                 C,
                 ldc);
}

template <>
template <typename T>
void Blas<GPU, T>::GEMM(bool transA,
                                 bool transB,
                                 int M,
                                 int N,
                                 int K,
                                 T alpha,
                                 const T *A,
                                 int lda,
                                 const T *B,
                                 int ldb,
                                 T beta,
                                 T *C,
                                 int ldc) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  CUBlas<T>::GEMM_EX(&cuda_ctx,
                       cuTransB,
                       cuTransA,
                       N,
                       M,
                       K,
                       &alpha,
                       B,
                       CUDA_R_32F,
                       ldb,
                       A,
                       CUDA_R_32F,
                       lda,
                       &beta,
                       C,
                       CUDA_R_32F,
                       ldc);

番外篇1

在现有的大模型推理引擎里面,包括FT,lmdeploy,llama.cpp等等,说实话都是hard code到不同的device或kernel实现,并没有一个统一的接口来发牌,包括我的课程三里面,也是如此,还没有实现这一套模板推导机制,我计划后面把这样一个模板推导的优雅思想加入自己课程三中。

番外篇2

NV现有的把C++模板推导运用到极致的案例就是cutlass,它支持多个GPU架构多个数据类型多个计算单元多条计算指令多个epilogues的实现,假如我手里只有V100,我怎么利用cutlass调到v100 1st gen tensorcore来做HGEMM,假如我手里有A100, 我怎么利用cutlass调到A100 3rd gen tensorcore来做HGEMM,并且融合bias和relu这俩epilogues,假如我手里有H100, 我怎么利用cutlass调到H100 4rd gen tensorcore来做FP8 GEMM,并且采用splitK的K维度处理策略。**以上cutlass是怎么分发的?**当然是模板推导!!yep!!

有一个类叫做DefaultGemm,这个就是个典型的负责模板推导的类,类似于上文的Blas,推导到各个case的偏特化/特例化实现。

比如下面这个针对volta gpu作的偏特化实现

/// Partial specialization for Volta architecture
template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Epilogue output operator
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// If true, kernel is configured to support serial reduction in the epilogue
  bool SplitKSerial,
  /// Operation performed by GEMM
  typename Operator,
  /// Use zfill or predicate for out-of-bound cp.async
  SharedMemoryClearOption SharedMemoryClear,
  /// Gather operand A by using an index array
  bool GatherA,
  /// Gather operand B by using an index array
  bool GatherB,
  /// Scatter result D by using an index array
  bool ScatterD
>
struct DefaultGemm<
  ElementA, LayoutA, kAlignmentA,
  ElementB, LayoutB, kAlignmentB,
  ElementC, layout::RowMajor,
  ElementAccumulator,
  arch::OpClassTensorOp,
  arch::Sm70,
  ThreadblockShape,
  WarpShape,
  GemmShape<8, 8, 4>,
  EpilogueOutputOp,
  ThreadblockSwizzle,
  2,
  SplitKSerial,
  Operator,
  SharedMemoryClear,
  GatherA,
  GatherB,
  ScatterD
>

最后,感谢读者点赞和看一看,欢迎关注我的公众号“AI不止算法”。

  • 25
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值