介绍
cublasGemmEx 是CUDA8.0中cuBLAS新出的函数,是cublasgemm()类函数的扩展,也是目前来看功能最强大的矩阵乘函数了。该函数另一强大之处在于支持多种计算模式(compute type),其中就包括CUDA 8.0新出的FP16和INT8。但是该函数的文档并不太健全,最近在使用这个函数实现INT8矩阵乘的时候就碰见坑了,照着文档用就是报错,找NVIDIA的工程师才给解决。下面总结一下使用经验,把坑填上,以防大家再踩。
函数原型
cublasStatus_t cublasGemmEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
cudaDataType_t computeType,
cublasGemmAlgo_t algo)
跟cublasSgemm长的比较像,但是多了这么几个参数,Atype,Btype,Ctype,computeType和algo。
这个函数的核心就是计算模式(computeType),computeType支持以下类型:
computeType | 解释 |
---|---|
CUDA_R_16F | FP16计算模式,输入输出都是FP16 |
CUDA_R_32F | FP32计算模式,这个比较强大,输入可以是FP16、INT8和FP32 |
CUDA_R_32I | INT8计算模式,也是本文着重要讲的模式 |
CUDA_R_64F | FP64计算模式 |
CUDA_C_32F | |
CUDA_C_64F |
每个computeType支持的输入类型和输出类型在cublasGemmEx文档中写的非常清楚,照着用就行了。但是,有一个隐含的坑就在CUDA_R_32I计算模式里。
正常按照 char *A, char *B, int *C是会报错CUBLAS_STATUS_NOT_SUPPORTED,这个错误官方的解释是“the combination of the parameters Atype, Btype and Ctype and the algorithm type, algo is not supported”,大概意思就是Atype,Btype,Ctype,和algo不匹配。但是明明是按文档上写的啊,因为错误根本不在这里。
解决办法
错误的原因是,如果要使用CUDA_R_32I计算模式,那么alpha和beta这两个参数也必须是int类型且必须是0或者1……神坑啊。
PS:CUDA_R_32I计算模式下,cublasGemmAlgo_t 参数好像也只支持前7种,这个在文档里也没说。
CUDA_R_32I与CUDA_R_32F计算对比结果
这里多说一点INT8矩阵乘计算模式吧,CUDA_R_32I计算模式里调用CUDA 8.0新出的INT8计算接口-dp4a,按照官方的理论,dp4a这个函数会将四个char组合成一个int进行乘法运算,将4次乘法和3次加法减少为一次高级指令,从而提高性能。
我的实验结果表明,CUDA_R_32I模式与CUDA_R_32F模式相比,最快能提高3.2倍(与矩阵的大小有关),同时能将数据压缩75%,这是一个非常可观的收益了。
但是FP32(float)量化成INT8(char)肯定是会有精度损失的,对INT8有兴趣的可以关注NVIDIA新出的TensorRT2.0,该库能够在一些情况下保持较高的精度实现INT8加速。
TensorRT给的资料也比较少,坑也特别多,因此我开了一个TensorRT_Tutorial,欢迎志同道合者一起参与。
