cublas 中 gemm 的用法

最近在看 llama2 的推理过程的源码,其中的矩阵乘法操作用到了 cublas 库。对其中的参数有较大的疑惑,所以打算梳理一下。

有一个很好的解释 cublas 中 gemm 的文章,链接如下有关CUBLAS中的矩阵乘法函数 - 爨爨爨好 - 博客园 (cnblogs.com)

具体的详细操作,请看给出的文章,我只是对其进行一个总结。

cublas 中矩阵乘法主要就是 14 个参数,该函数实际上是用于计算 C=\alphaAB+\betaC, 其中 A 大小默认为 m×k,B 大小默认为 k×n,C 大小默认为 m×n,\alpha 和 \beta 为标量。

cublasHandle_t handle  为调用 cuBLAS 库时的句柄

cublasOperation_t transa 是否转置矩阵 A

cublasOperation_t transb 是否转置矩阵 B

int m 矩阵A的行数,等于矩阵C的列数。

int n 矩阵B的列数,等于矩阵C的行数。

int k 矩阵A的列数,等于矩阵B的行数。

m, n, k 这三个字母的作用就是确定输入矩阵 A, B, C 中元素的个数(因为float * 这种类型中并未蕴含元素个数的信息)。因为要运行的函数是矩阵乘法,那么只需要这三个数字就能覆盖A,B,C中元素的个数。

const float *alpha 标量 α 的指针,可以是主机指针或设备指针,只需要计算矩阵乘法时命 α = 1.0f

const float *A  矩阵(数组)A 的指针,必须是设备指针

const float *B 矩阵(数组)B 的指针,必须是设备指针

float *C 矩阵(数组)C 的指针,必须是设备指针

const float *beta 标量 β 的指针,可以是主机指针或设备指针,只需要计算矩阵乘法时命 β = 0.0f

int lda 矩阵 A 的主维(leading dimension)(lda 要大于等于 m)

int ldb 矩阵 B 的主维 (ldb 要等于等于 k)

int ldc 矩阵 C 的主维 (ldc 要大于等于 m)

重点说下这个主维的作用,cublas 中是矩阵是以列优先的,而C++中是以行优先的,MATLAB 也是以列优先。

cublas 这个函数中有一个自己的 A1, B1, C1。输入的 A, B 都换先转换到 A1, B1, 然后 A1和B1做矩阵乘法得到C1,然后再将C1转换到 C中。

以C++进行举例,

大小为 m×k 的矩阵 A(行主序)会首先转换到 lda × k 的矩阵 A1,用0补充可能多余的空位。

大小为 k×n 的矩阵 B(行主序)会首先转换到 ldb × n 的矩阵 B1,用0补充可能多余的空位。

然后用 A1 乘以部分 B1 (前k行) 得到的大小为 lda×n C1(前ldc行)。

然后将列主序的C1转换到行主序的大小为 m×n的 C。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值