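Purpose:
I needed to speed up the product of two 3-D arrays with shapes [30, 128, 1536]^T * [30, 128, 1536] = [30, 1536, 1536], where each 128 x 1536 slice of the first array is transposed before the multiplication. Doing this with OpenBLAS meant understanding cblas_sgemm, so after some testing I am recording my findings here for future reference.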
Example:
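#include <cblas.h>
#include <cstdio>
// Batched product C[i] = A[i]^T * B[i] for i = 0 .. dim0-1, all arrays row-major.
void mul() {
    const int dim0 = 2;  // number of slices (batch)
    const int dim1 = 2;  // rows of each A[i] and B[i]: the shared dimension K
    const int dim2 = 3;  // columns of each A[i]: rows of each C[i] (M)
    const int dim3 = 4;  // columns of each B[i] and C[i] (N)
    float A[dim0 * dim1 * dim2] = { 1, 2, 3, 6, 5, 4, 2, 2, 3, 6, 5, 6 };
    float B[dim0 * dim1 * dim3] = { 1, 2, 3, 6, 5, 4, 2, 2, 3, 6, 5, 6, 1, 2, 3, 6 };
    float C[dim0 * dim2 * dim3] = { 0 };
    for (int i = 0; i < dim0; ++i) {
        // M = dim2, N = dim3, K = dim1; lda = dim2, ldb = dim3, ldc = dim3
        cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, dim2, dim3, dim1, 1.0f, &A[i * dim1 * dim2], dim2,
                    &B[i * dim1 * dim3], dim3, 0.0f, &C[i * dim2 * dim3], dim3);
    }
    // Print each dim2 x dim3 slice of C, one row per line.
    for (int j = 0; j < dim0 * dim2 * dim3; ++j) {
        printf("%g%s", C[j], (j + 1) % dim3 == 0 ? "\n" : " ");
    }
}
int main() { mul(); return 0; }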
Result (C, two 3 x 4 slices):
[[[31, 26, 15, 18],
[27, 24, 16, 22],
[23, 22, 17, 26]],
[[12, 24, 28, 48],
[11, 22, 25, 42],
[15, 30, 33, 54]]]
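To check the numbers above without OpenBLAS, a plain triple loop that computes each slice C[b] = A[b]^T * B[b] can serve as a reference. This is a minimal sketch assuming the same contiguous row-major layout as the example; the function name batched_at_b is hypothetical and not part of OpenBLAS.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference for the batched product C[b] = A[b]^T * B[b], b = 0..batch-1.
// Row-major layout as in the example: A[b] is k x m, B[b] is k x n, C[b] is m x n.
std::vector<float> batched_at_b(const std::vector<float>& A, const std::vector<float>& B,
                                int batch, int k, int m, int n) {
    assert(A.size() == static_cast<std::size_t>(batch) * k * m);
    assert(B.size() == static_cast<std::size_t>(batch) * k * n);
    std::vector<float> C(static_cast<std::size_t>(batch) * m * n, 0.0f);
    for (int b = 0; b < batch; ++b) {
        const float* a = A.data() + static_cast<std::size_t>(b) * k * m;   // slice A[b]
        const float* bm = B.data() + static_cast<std::size_t>(b) * k * n;  // slice B[b]
        float* c = C.data() + static_cast<std::size_t>(b) * m * n;         // slice C[b]
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float sum = 0.0f;
                for (int p = 0; p < k; ++p) {
                    sum += a[p * m + i] * bm[p * n + j];  // A^T(i, p) = A(p, i)
                }
                c[i * n + j] = sum;
            }
        }
    }
    return C;
}
```

Calling batched_at_b(A, B, 2, 2, 3, 4) with the example's arrays (batch = dim0, k = dim1, m = dim2, n = dim3) returns the 24 values of C, which can be compared element by element with the cblas_sgemm output.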
Understanding the parameters
void cblas_sgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST float *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST float *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
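Order: whether adjacent elements of the input arrays run along rows (CblasRowMajor) or along columns (CblasColMajor);
TransA: whether to transpose matrix A (CblasTrans) or use it as is (CblasNoTrans);
TransB: whether to transpose matrix B;
M : number of rows of A and of the result C, counted after any transpose of A;
N : number of columns of B and of the result C, counted after any transpose of B;
K : number of columns of A and number of rows of B, counted after any transpose (the shared inner dimension);
alpha, beta: the scalars in C = alpha * op(A) * op(B) + beta * C, where op(X) is X or its transpose; alpha = 1 and beta = 0 give a plain product;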
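float *A: address of the first element of matrix A;
lda: if A is transposed, the number of rows of A after transposition; if A is not transposed, the number of columns of A. In both cases this equals the number of columns of A as stored in memory, i.e. the distance between the starts of consecutive rows;
float *B: address of the first element of matrix B;
ldb: if B is transposed, the number of rows of B after transposition; if B is not transposed, the number of columns of B. Again this is the column count of B as stored in memory;
float *C: address of the first element of the result C;
ldc: the number of columns of the result C, i.e. its row stride (dim3 in the example);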
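The parameter rules above can be modeled with a naive row-major sgemm that reads memory through lda, ldb and ldc. This is a sketch for understanding only, assuming row-major storage: ref_sgemm_rowmajor is a hypothetical name, the bool flags stand in for CblasTrans/CblasNoTrans, and it is not OpenBLAS's implementation.

```cpp
#include <cassert>

// Hypothetical model of a row-major sgemm: C = alpha * op(A) * op(B) + beta * C.
// transA / transB == true means op(X) = X^T.
void ref_sgemm_rowmajor(bool transA, bool transB, int M, int N, int K,
                        float alpha, const float* A, int lda,
                        const float* B, int ldb,
                        float beta, float* C, int ldc) {
    // Stored shapes: A is M x K (K x M if transposed), B is K x N (N x K if transposed),
    // so each leading dimension must be at least the stored column count.
    assert(lda >= (transA ? M : K));
    assert(ldb >= (transB ? K : N));
    assert(ldc >= N);
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int p = 0; p < K; ++p) {
                // Element (r, c) of a row-major matrix with leading dimension ld is at r * ld + c.
                float a = transA ? A[p * lda + i] : A[i * lda + p];
                float b = transB ? B[j * ldb + p] : B[p * ldb + j];
                sum += a * b;
            }
            // As in reference BLAS, C is not read when beta == 0.
            float old = (beta == 0.0f) ? 0.0f : beta * C[i * ldc + j];
            C[i * ldc + j] = alpha * sum + old;
        }
    }
}
```

Because element (r, c) lives at r * ld + c, a leading dimension larger than the stored column count selects a submatrix of a wider array. For example, ref_sgemm_rowmajor(false, false, 2, 2, 2, 1.0f, big, 4, I, 2, 0.0f, S, 2) multiplies the top-left 2 x 2 block of a 3 x 4 array big by a 2 x 2 matrix I.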
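Note:
The example above is row-major, so these conclusions were verified only for row-major storage. With CblasColMajor the leading dimensions count rows instead: lda, ldb and ldc are the number of rows of A, B and C as stored in memory.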
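Applying the row-major rules above to the shapes from the purpose section: each 128 x 1536 slice of the first array is transposed, so M = 1536, N = 1536, K = 128 and lda = ldb = ldc = 1536. The sketch below derives these values from the shapes; SliceGemmArgs and batched_at_b_args are hypothetical names, and it assumes all three arrays are contiguous and row-major like the example.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: cblas_sgemm arguments for one slice of C[b] = A[b]^T * B[b],
// where A is [batch, k, m], B is [batch, k, n], C is [batch, m, n], all row-major.
struct SliceGemmArgs {
    int M, N, K;                            // op(A) is M x K, op(B) is K x N, C[b] is M x N
    int lda, ldb, ldc;                      // column counts of A[b], B[b], C[b] as stored
    std::size_t strideA, strideB, strideC;  // element offset between consecutive slices
};

SliceGemmArgs batched_at_b_args(int k, int m, int n) {
    assert(k > 0 && m > 0 && n > 0);
    SliceGemmArgs g;
    g.M = m;
    g.N = n;
    g.K = k;
    g.lda = m;  // A[b] is stored k x m and transposed: lda = stored column count
    g.ldb = n;  // B[b] is stored k x n, not transposed
    g.ldc = n;  // C[b] is m x n
    g.strideA = static_cast<std::size_t>(k) * m;
    g.strideB = static_cast<std::size_t>(k) * n;
    g.strideC = static_cast<std::size_t>(m) * n;
    return g;
}

// With OpenBLAS linked, slice b would then be computed as:
// cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, g.M, g.N, g.K, 1.0f,
//             A + b * g.strideA, g.lda, B + b * g.strideB, g.ldb,
//             0.0f, C + b * g.strideC, g.ldc);
```

For the example, batched_at_b_args(2, 3, 4) reproduces the arguments used in the loop (M = 3, N = 4, K = 2, lda = 3, ldb = ldc = 4). For the real shapes, C alone holds 30 x 1536 x 1536 = 70,778,880 floats (about 283 MB), so it should live on the heap rather than in a local array as in the example.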