template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation) {
using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = 0; // bias is added in epilogue
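  // Map Dtype to the cuBLASLt data/compute/scale types. The fp32 case may use
  // the TF32 compute path when at::globalContext().allowTF32CuBLAS() is enabled.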
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
if (std::is_same<Dtype, double>::value) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if (std::is_same<Dtype, float>::value) {
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
abcType = CUDA_R_32F;
} else if (std::is_same<Dtype, at::Half>::value) {
abcType = CUDA_R_16F;
} else if (std::is_same<Dtype, at::BFloat16>::value) {
abcType = CUDA_R_16BF;
}
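  // Describe the matmul: compute/scale types plus the transpose flags for A and B.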
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSA,
&transa,
sizeof(transa)));
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSB,
&transb,
sizeof(transb)));
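  // Pick the epilogue: bias add alone, or bias fused with ReLU/GELU
  // (the GELU epilogue is only available on CUDA 11.4+).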
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias,
sizeof(Dtype*)));
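  // Layout descriptors are column-major; a transposed input is stored with its
  // dimensions swapped, and C/D is m x n with leading dimension result_ld.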
CuBlasLtMatrixLayout Adesc(
abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
CuBlasLtMatrixLayout Bdesc(
abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = 1024 * 1024;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
preference.descriptor(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize,
sizeof(workspaceSize)));
auto workspace = at::empty(
{static_cast<int64_t>(workspaceSize)},
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
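  // Launch the fused GEMM + bias (+ activation) on the current CUDA stream.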
TORCH_CUDABLAS_CHECK(cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream()));
}
The code above is how PyTorch calls cublasLtMatmul to run the GEMM and add the bias in a single fused call; it is nicely written.
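For context, here is a minimal sketch of what a caller could look like. It is an illustration, not code from PyTorch: the buffer names (d_mat1, d_mat2, d_bias, d_out) and the wrapper function are hypothetical, and it assumes gemm_and_bias and GEMMAndBiasActivationEpilogue are reachable through at::cuda::blas as declared in ATen/cuda/CUDABlas.h. Buffers follow the cuBLASLt column-major convention: mat1 is m x k, mat2 is k x n, the result is m x n, and the bias has length m and is broadcast across the n output columns by the *_BIAS epilogue.

#include <cstdint>
#include <ATen/cuda/CUDABlas.h>

// Hypothetical wrapper: computes d_out = relu(d_mat1 * d_mat2 + d_bias)
// in one cuBLASLt call via the fused bias + ReLU epilogue.
void fused_linear_sketch(
    const float* d_mat1,  // m x k, column-major, leading dimension m
    const float* d_mat2,  // k x n, column-major, leading dimension k
    const float* d_bias,  // length m, broadcast over the n output columns
    float* d_out,         // m x n, column-major, leading dimension m
    int64_t m, int64_t n, int64_t k) {
  at::cuda::blas::gemm_and_bias<float>(
      /*transpose_mat1=*/false,
      /*transpose_mat2=*/false,
      m, n, k,
      /*alpha_val=*/1.0f,
      d_mat1, /*mat1_ld=*/m,
      d_mat2, /*mat2_ld=*/k,
      d_bias,
      d_out, /*result_ld=*/m,
      at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU);
}

Since Dtype is float here, the type-selection branch at the top of gemm_and_bias keeps everything in CUDA_R_32F and only switches the compute type to TF32 when the global TF32 flag allows it.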