PyTorch 调用 cublasLtMatmul 执行 GEMM 并同时加 bias 的代码,写得不错。

		

// Computes result = alpha * op(mat1) * op(mat2) with `bias` added — and an
// optional activation applied — inside the cuBLASLt epilogue, so no separate
// bias/activation kernel launch is needed.
//
// Dtype is one of double, float, at::Half, at::BFloat16 (see the type
// dispatch below). Following cuBLAS conventions, matrices are column-major:
// mat1 is m x k (stored k x m when transpose_mat1), mat2 is k x n (stored
// n x k when transpose_mat2), result is m x n; *_ld are leading dimensions.
// `bias` is read by the epilogue; its expected length is m (one element per
// output row) per the CUBLASLT_EPILOGUE_*BIAS contract — confirm against the
// cuBLASLt docs for the toolkit in use.
template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {
  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = 0; // bias is added in epilogue

  // Map Dtype to the matching CUDA data/compute/scale types.
  // Defaults cover float; the branches below override per type.
  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if (std::is_same<Dtype, double>::value) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if (std::is_same<Dtype, float>::value) {
    // Opt into TF32 tensor-core math for fp32 only when globally allowed.
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
    abcType = CUDA_R_32F;
  } else if (std::is_same<Dtype, at::Half>::value) {
    // Half/BFloat16 keep fp32 compute/scale types (set above); only the
    // matrix storage type changes.
    abcType = CUDA_R_16F;
  } else if (std::is_same<Dtype, at::BFloat16>::value) {
    abcType = CUDA_R_16BF;
  }

  // Operation descriptor: transposes, epilogue kind, and bias pointer.
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_TRANSA,
      &transa,
      sizeof(transa)));
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_TRANSB,
      &transb,
      sizeof(transb)));
  // Bias add always happens in the epilogue; the requested activation is
  // fused on top of it when supported.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
    // NOTE(review): when CUDA_VERSION < 11040 a requested GELU silently
    // degrades to bias-only (epilogue stays CUBLASLT_EPILOGUE_BIAS) — callers
    // presumably guard on the CUDA version; confirm, or an error/warning
    // here may be warranted.
  }
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_EPILOGUE,
      &epilogue,
      sizeof(epilogue)));
  TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
      computeDesc.descriptor(),
      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
      &bias,
      sizeof(Dtype*)));

  // Layout descriptors describe the matrices as *stored* (pre-transpose),
  // hence the swapped rows/cols when a transpose flag is set.
  CuBlasLtMatrixLayout Adesc(
      abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
  CuBlasLtMatrixLayout Bdesc(
      abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = 1024 * 1024;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
      preference.descriptor(),
      CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize,
      sizeof(workspaceSize)));

  // Workspace comes from ATen's CUDA caching allocator (a byte tensor), so
  // it is freed/reused via the normal tensor lifetime rules.
  auto workspace = at::empty(
      {static_cast<int64_t>(workspaceSize)},
      at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));

  // Ask the heuristic for the single best algorithm for this problem.
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  // NOTE(review): the cuBLASLt handle is obtained by reinterpreting the
  // regular cuBLAS handle; this relies on the documented compatibility of
  // the two handle types — verify against the cuBLAS docs for the toolkit
  // version being targeted.
  cublasLtHandle_t ltHandle =
      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  // Zero returned algorithms means no implementation supports this
  // type/shape/epilogue combination under the workspace limit.
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  // C and D both use result_ptr/Cdesc (in-place D == C); since beta == 0 the
  // prior contents of result_ptr are never read, only overwritten.
  TORCH_CUDABLAS_CHECK(cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.data_ptr(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream()));
}
		

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值