最近在看 llama2 的推理过程的源码,其中的矩阵乘法操作用到了 cublas 库。对其中的参数有较大的疑惑,所以打算梳理一下。
有一个很好的解释 cublas 中 gemm 的文章,链接如下有关CUBLAS中的矩阵乘法函数 - 爨爨爨好 - 博客园 (cnblogs.com)
具体的详细操作,请看给出的文章,我只是对其进行一个总结。
cublas 中矩阵乘法主要就是 14 个参数,该函数实际上是用于计算 C=AB+C, 其中 A 大小默认为 m×k,B 大小默认为 k×n,C 大小默认为 m×n, 和 为标量。
cublasHandle_t handle 为调用 cuBLAS 库时的句柄
cublasOperation_t transa 是否转置矩阵 A
cublasOperation_t transb 是否转置矩阵 B
int m 矩阵A的行数,等于矩阵C的列数。
int n 矩阵B的列数,等于矩阵C的行数。
int k 矩阵A的列数,等于矩阵B的行数。
m, n, k 这三个字母的作用就是确定输入矩阵 A, B, C 中元素的个数(因为float * 这种类型中并未蕴含元素个数的信息)。因为要运行的函数是矩阵乘法,那么只需要这三个数字就能覆盖A,B,C中元素的个数。
const float *alpha 标量 α 的指针,可以是主机指针或设备指针,只需要计算矩阵乘法时命 α = 1.0f
const float *A 矩阵(数组)A 的指针,必须是设备指针
const float *B 矩阵(数组)B 的指针,必须是设备指针
float *C 矩阵(数组)C 的指针,必须是设备指针
const float *beta 标量 β 的指针,可以是主机指针或设备指针,只需要计算矩阵乘法时命 β = 0.0f
int lda 矩阵 A 的主维(leading dimension)(lda 要大于等于 m)
int ldb 矩阵 B 的主维 (ldb 要等于等于 k)
int ldc 矩阵 C 的主维 (ldc 要大于等于 m)
重点说下这个主维的作用,cublas 中是矩阵是以列优先的,而C++中是以行优先的,MATLAB 也是以列优先。
cublas 这个函数中有一个自己的 A1, B1, C1。输入的 A, B 都换先转换到 A1, B1, 然后 A1和B1做矩阵乘法得到C1,然后再将C1转换到 C中。
以C++进行举例,
大小为 m×k 的矩阵 A(行主序)会首先转换到 lda × k 的矩阵 A1,用0补充可能多余的空位。
大小为 k×n 的矩阵 B(行主序)会首先转换到 ldb × n 的矩阵 B1,用0补充可能多余的空位。
然后用 A1 乘以部分 B1 (前k行) 得到的大小为 lda×n C1(前ldc行)。
然后将列主序的C1转换到行主序的大小为 m×n的 C。