如何加速矩阵乘法

矩阵的内存格式:

矩阵A中第i行第j列可以表示为A[i][j]。但是在系统底层的内存中,只有连续的存储空间,编译器会将高级语言的二维数组的访存转换为对一维内存的访问。将二维坐标映射成一维坐标有很多方法,直接的有两种:行主序(row-major order)和列主序(column-major order)。简单地说,对于前面提到的A[i][j]元素,如果映射为一维数组,既可以用行主序的方式,如果一行的内存元素有lda个,则A[i][j]对应于A[i+j*lda]。之所以这里说内存元素有lda个而不说一行有n个元素或者说一列有n个元素。例如,对于行主列的存储方式而言,矩阵一行有n个元素,但是矩阵一行占用内存不一定是nxsizeof(Element of A)个,可能由于内存对齐的原因,一行占用了ldaxsizeof(Element of A)个内存,且lda>n. 因此,表示矩阵存储时,需要知道矩阵的存储格式是行主序还是列主序。一版二维数组使用的是行主序的存储方式。

本文和相关代码使用列主序的方式表示矩阵元素,即矩阵A的第i行第j列元素Aij对应为A[i+j*lda].另外在代码是中为了方便,很多地方lda=n,因此也可以直接表示成A[i+j*n]。

测试方法:

围绕计算C=C+AB展开,其中A为M行K列矩阵,B为K行N列矩阵,C为M行N列矩阵,测试中的矩阵元素类型为单精度float.

本文测试中的CPU使用的是arm架构的鲲鹏920,主频2.6GHz,具有128bit SIMD寄存器与向量指令集。每核有私有512KB L2 Cache。

本文性能测试的数据是在 M=N=K的情况下,测试N从127到1281多个数据规模的性能,评价的指标是每秒浮点运算次数Flops.

使用OpenBLAS 作为优化目标。

朴素实现

朴素矩阵乘法就是一个简单的O(N^3)算法,三层循环。对于C=C+AB这个矩阵乘法核加法,可以用下面代码描述:

for(int i=0; i<M;++i){

    for(int j=0; j < N; ++j){

       for(int k =0; k<K;++k){

        c[i+j*ldc] + =A[i+k*lda] * B[k+j*ldb];

}

}

}

直接优化,例如对C[i][j]的累加过程使用一个寄存器中的变量暂存C[i][j],编译器为了防止指针Alias产生,不能自动进行这一优化。

理论上的矩阵乘法要进行O(N^3)量级的乘加计算,内存数据量只有O(N^2)规模,内存访问的次数是O(N^3)。 GEMM是计算密集型核访存密集型。但是内存的访问带宽远远小于CPU的浮点计算吞吐量,访存延迟也远高于CPU的浮点指令延迟,而CPU的Cache虽快,大小十分有限,不能塞下所有的矩阵数据。因此,朴素实现的瓶颈会在于访存指令的延迟,每个浮点计算指令都要邓访存指令执行完成。

那么我们优化的方向是,要解决访存带宽的问题,这个问题可以通过将矩阵分块(Blocking) 的方法解决,虽然整个矩阵装不进Cache里,但是一个足够小的子矩阵还是可以的。将矩阵分块后,计算的次数和访存次数不变,但是更多的访存指令能从Cache直接获取数据,大大减小了平均访问延迟。

另外,我们还可以使用SIMD向量化指令,一个指令可以操控多个数据,可以提高访存和计算吞吐率。

数据并行化的准备:

为了之后的SIMD向量化,需要对算法的实现步骤进行调整。例如,对于矩阵C中的每一个元素的计算,朴素实现中每次只拿矩阵A的一行去点乘矩阵B的一列。而如果我们同时算C中的4个元素,那么我们就能同时取A中的连续4行与B中的连续的4列。这样连续取整块的内存是有利于后面的向量化的。

具体而言,我们每次同时计算C的4x4的16个元素,这样的话我们同时用到了A的连续4行与B的连续4列。

内层循环的核心代码如下,注意其中的C[0][0]对应的是4x4的小方块中的C[0][0]而不是整个大矩阵中的位置

for(int k =0;k<K;++k){

c[0+0*ldc] + = A[0+k*lda]*B[k+0*ldb];

c[0+1*ldc] + = A[0+k*lda]*B[k+1*ldb];

c[0+2*ldc] + = A[0+k*lda]*B[k+2*ldb];

c[0+3*ldc] + = A[0+k*lda]*B[k+3*ldb];

c[1+0*ldc] + = A[1+k*lda]*B[k+0*ldb];

c[1+1*ldc] + = A[1+k*lda]*B[k+1*ldb];

c[1+2*ldc] + = A[1+k*lda]*B[k+2*ldb];

c[1+3*ldc] + = A[1+k*lda]*B[k+3*ldb];

c[2+0*ldc] + = A[2+k*lda]*B[k+0*ldb];

c[2+1*ldc] + = A[2+k*lda]*B[k+1*ldb];

c[2+2*ldc] + = A[2+k*lda]*B[k+2*ldb];

c[2+3*ldc] + = A[2+k*lda]*B[k+3*ldb];

c[3+0*ldc] + = A[3+k*lda]*B[k+0*ldb];

c[3+1*ldc] + = A[3+k*lda]*B[k+1*ldb];

c[3+2*ldc] + = A[3+k*lda]*B[k+2*ldb];

c[3+3*ldc] + = A[3+k*lda]*B[k+3*ldb];

}

这一部分我写的可能不是很详细,大家可以参考Home · flame/how-to-optimize-gemm Wiki · GitHub 的做法。

利用寄存器减少访存次数

上面打代码有一些明显的可以减少访存次数的优化方法,例如将16个C中的元素都先用寄存器暂存,然后累加时使用寄存器,最后再写入内存。

另外,内层循环的每一次迭代中,A中的访问内存的次数为16次,B也为16次,但其实A与B都只各自访问了4个元素,因此其实也可以各用4个寄存器先加载内存,然后使用寄存器计算。

减少访存后,平均浮点性能为5.385 Gflops,这相对于之前是一个很大的提升。、

内层循环的核心代码如下:

register float c00 = 0, c01 = 0, c02 = 0, c03 = 0, c10 = 0, c11 = 0, c12 = 0, c13 = 0, c20 = 0, c21 = 0, c22 = 0, c23 = 0, c30 = 0, c31 = 0, c32 = 0, c33 = 0;

 register float a0i, a1i, a2i, a3i;
 register float b0i, b1i, b2i, b3i;
float *bi0_p, *bi1_p, *bi2_p, *bi3_p;

bi0_p =B;bi1_p = B+1*ldb; bi2_p = B+2*ldb; bi3_p = B + 3*ldb;

for(int i =0; i<n; ++i){

a0i = A[i * lda]; a1i = A[1 + i * lda]; a2i = A[2 + i * lda]; a3i = A[3 + i * lda];

   bi0 = *bi0_p++; bi1 = *bi1_p++; bi2 = *bi2_p++; bi3 = *bi3_p++;

c00 + = a0i*b0i;

c01 +=a0i*bi1;

c02 +=a0i*bi2;

c03 +=a0i*bi3;

c10 + = a0i*b0i;

c11 +=a0i*bi1;

c12 +=a0i*bi2;

c13 +=a0i*bi3;

c20 + = a0i*b0i;

c21 +=a0i*bi1;

c22 +=a0i*bi2;

c23 +=a0i*bi3;

c30 + = a0i*b0i;

c31 +=a0i*bi1;

c32 +=a0i*bi2;

c33 +=a0i*bi3;

}

c[0+0*ldc] +=c00;c[0+1*ldc] +=c01;c[0+2*ldc]+=c02;c[0+3*ldc]+=c03;

 C[1 + 0 * ldc] += c10;  C[1 + 1 * ldc] += c11;  C[1 + 2 * ldc] += c12;  C[1 + 3 * ldc] += c13;
 C[2 + 0 * ldc] += c20;  C[2 + 1 * ldc] += c21;  C[2 + 2 * ldc] += c22;  C[2 + 3 * ldc] += c23;
 C[3 + 0 * ldc] += c30;  C[3 + 1 * ldc] += c31;  C[3 + 2 * ldc] += c32;  C[3 + 3 * ldc] += c33;

SIMD 向量化

具体到arm架构的CPU上,我们使用Neon向量指令集,每个指令可以操控128bit即4个float数据。

float32x4_t c_c0, c_c1, c_c2, c_c3, a_ri, b_vi0, b_vi1, b_vi2, b_vi3;
 c_c0 = vmovq_n_f32(0.0), c_c1 = vmovq_n_f32(0.0), c_c2 = vmovq_n_f32(0.0), c_c3 = vmovq_n_f32(0.0);
 float *bi0_p, *bi1_p, *bi2_p, *bi3_p;

 bi0_p = B; bi1_p = B + 1 * ldb; bi2_p = B + 2 * ldb; bi3_p = B + 3 * ldb;

for(int i=0;i<n;++i){

   a_ri = vld1q_f32(A + i * lda);
   b_vi0 = vld1q_dup_f32(bi0_p++); b_vi1 = vld1q_dup_f32(bi1_p++); 
   b_vi2 = vld1q_dup_f32(bi2_p++); b_vi3 = vld1q_dup_f32(bi3_p++);

   c_c0 = vmlaq_f32(c_c0, a_ri, b_vi0);
   c_c1 = vmlaq_f32(c_c1, a_ri, b_vi1);
   c_c2 = vmlaq_f32(c_c2, a_ri, b_vi2);
   c_c3 = vmlaq_f32(c_c3, a_ri, b_vi3);

}

 vst1q_f32(C + 0 * ldc, c_c0); vst1q_f32(C + 1 * ldc, c_c1);
 vst1q_f32(C + 2 * ldc, c_c2); vst1q_f32(C + 3 * ldc, c_c3);

矩阵分块Blocking

之前的方法在数据规模变大后性能会有较大的下降,问题在于L2 Cache大小有限,如果不停地访存很快会把L2 Cache刷满一遍。因此应该让高密度计算中的访存范围集中,使用Blocked 分块的方法。

采取的分块策略是A每次访问MCxKC的块,B每次访问KCxldb的块。其中MC和KC的大小进行多次尝试决定。

for(int k =0; k<n;k+kc){

int k = min(n-k,KC);

for(int i=0;i<n;i+=MC){

int M = min(n-i,MC);

int N = n;

  do_block(M, N, K, n, n, n, A + i + k * n, B + k, C + i);
    }
}

内存重排Packing

每次对A进行4x1的访存时,是对4行同一列元素进行的访存,但是每次循环就会跳到下一列,由于是列主序的矩阵,循环间访存不是连续的,最好将所有循环访问的元素排到一起,可以减少跨区域访存的次数,增加数据在一条Cache Line中的概率。因此对A进行了4行元素的内存重排。

每次对B是进行1x4的访存,每次对同一行4列的元素访问,由于列主序,一次循环内的访存是不连续的,可以将4列的元素重排到一起。

进一步提高计算/访存比

之前代码的kernel部分,循环内一次计算取A的四个float,B的四个float,但是,B的每个float是用来标量乘A向量的,因此当时的做法是把B的每个float重复为1个32bitx4的向量再与A的向量相乘。

a_ri = vld1q_f32(A);
b_vi0 = vld1q_dup_f32(B); 
b_vi1 = vld1q_dup_f32(B + 1);
b_vi2 = vld1q_dup_f32(B + 2); 
b_vi3 = vld1q_dup_f32(B + 3);
c_c0 = vmlaq_f32(c_c0, a_ri, b_vi0);
c_c1 = vmlaq_f32(c_c1, a_ri, b_vi1);
c_c2 = vmlaq_f32(c_c2, a_ri, b_vi2);
c_c3 = vmlaq_f32(c_c3, a_ri, b_vi3);

这段代码有5次访存,4次fma(乘加指令)向量乘(mla指令和fma的效果是一样的)。

a_ri = vld1q_f32(A + i * lda);
b_vi0 = vld1q_f32(B);
c_c0 = vfmaq_laneq_f32(c_c0, a_ri, b_vi0, 0);
c_c1 = vfmaq_laneq_f32(c_c1, a_ri, b_vi0, 1);
c_c2 = vfmaq_laneq_f32(c_c2, a_ri, b_vi0, 2);
c_c3 = vfmaq_laneq_f32(c_c3, a_ri, b_vi0, 3);

重新排布指令,结合乘加

上面的代码中的求和部分有一些是不必要的,只要写成部分和,循环结束后再规约求和,另外还有一些求和可以利用fma指令在计算乘法时顺便求出,因此我们重新排布了指令如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
a_c0_i0 = vld1q_f32(A + i * 4);
a_c0_i1 = vld1q_f32(A + i * 4 + 4);
a_c0_i2 = vld1q_f32(A + i * 4 + 8);
a_c0_i3 = vld1q_f32(A + i * 4 + 12);
b_vi0_0 = vld1q_f32(B + 0);
b_vi1_0 = vld1q_f32(B + 4);
b_vi2_0 = vld1q_f32(B + 8);
b_vi3_0 = vld1q_f32(B + 12);
temp_v0 = vfmaq_laneq_f32(c_c0, a_c0_i0, b_vi0_0, 0);
c_c0 = vfmaq_laneq_f32(temp_v0, a_c0_i1, b_vi1_0, 0);
temp_v3 = vfmaq_laneq_f32(c_c1, a_c0_i1, b_vi1_0, 1);
c_c1 = vfmaq_laneq_f32(temp_v3, a_c0_i0, b_vi0_0, 1);
temp_v0 = vfmaq_laneq_f32(c_c2, a_c0_i0, b_vi0_0, 2);
c_c2 = vfmaq_laneq_f32(temp_v0, a_c0_i1, b_vi1_0, 2);
temp_v3 = vfmaq_laneq_f32(c_c3, a_c0_i1, b_vi1_0, 3);
c_c3 = vfmaq_laneq_f32(temp_v3, a_c0_i0, b_vi0_0, 3);


temp_v0 = vfmaq_laneq_f32(part1_c0, a_c0_i2, b_vi2_0, 0);
part1_c0 = vfmaq_laneq_f32(temp_v0, a_c0_i3, b_vi3_0, 0);
temp_v3 = vfmaq_laneq_f32(part1_c1, a_c0_i3, b_vi3_0, 1);
part1_c1 = vfmaq_laneq_f32(temp_v3, a_c0_i2, b_vi2_0, 1);
temp_v0 = vfmaq_laneq_f32(part1_c2, a_c0_i2, b_vi2_0, 2);
part1_c2 = vfmaq_laneq_f32(temp_v0, a_c0_i3, b_vi3_0, 2);
temp_v3 = vfmaq_laneq_f32(part1_c3, a_c0_i3, b_vi3_0, 3);
part1_c3 = vfmaq_laneq_f32(temp_v3, a_c0_i2, b_vi2_0, 3);

生成的汇编如下,浮点相关的全为乘加指令,4个ldp指令load 8个寄存器,16个fma指令,计算访存比为2.0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
4011e0:	ad405854 	ldp	q20, q22, [x2]
4011e4:	ad405cb5 	ldp	q21, q23, [x5]
4011e8:	ad414850 	ldp	q16, q18, [x2, #32]
4011ec:	91010042 	add	x2, x2, #0x40
4011f0:	ad414cb1 	ldp	q17, q19, [x5, #32]
4011f4:	910100a5 	add	x5, x5, #0x40
4011f8:	4f951287 	fmla	v7.4s, v20.4s, v21.s[0]
4011fc:	eb0100bf 	cmp	x5, x1
401200:	4f951a83 	fmla	v3.4s, v20.4s, v21.s[2]
401204:	4fb712c5 	fmla	v5.4s, v22.4s, v23.s[1]
401208:	4fb71ac1 	fmla	v1.4s, v22.4s, v23.s[3]
40120c:	4f911206 	fmla	v6.4s, v16.4s, v17.s[0]
401210:	4f911a02 	fmla	v2.4s, v16.4s, v17.s[2]
401214:	4fb31244 	fmla	v4.4s, v18.4s, v19.s[1]
401218:	4fb31a40 	fmla	v0.4s, v18.4s, v19.s[3]
40121c:	4f9712c7 	fmla	v7.4s, v22.4s, v23.s[0]
401220:	4f971ac3 	fmla	v3.4s, v22.4s, v23.s[2]
401224:	4fb51285 	fmla	v5.4s, v20.4s, v21.s[1]
401228:	4fb51a81 	fmla	v1.4s, v20.4s, v21.s[3]
40122c:	4f931246 	fmla	v6.4s, v18.4s, v19.s[0]
401230:	4f931a42 	fmla	v2.4s, v18.4s, v19.s[2]
401234:	4fb11204 	fmla	v4.4s, v16.4s, v17.s[1]
401238:	4fb11a00 	fmla	v0.4s, v16.4s, v17.s[3]

平均性能为25.952 Gflops,开了-O3优化后平均性能为26.124Gflops,峰值性能为27.8 Gflops

再增加循环展开层数

之前一次加载16个A中的元素,如果一次加载32个A中的float,进行8x4的一次load,则性能还能有微弱提升

平均性能为26.22 Gflops,峰值性能为28.3 Gflops

其实接下来还有更多优化方法,例如Prefetch技术,在本轮迭代的计算中就先fetch下一轮迭代会用到的数据,充分利用流水线隐藏延迟,能够进一步提高性能。但由于这些技术高度和具体CPU架构相关且测试较为复杂, 因此我们就不在此讨论。

关于研究矩阵乘法优化的思考

我们在这里探求单线程的矩阵乘法有没有什么用呢?其实对于绝大部分的人都是用不到的,而且我们研究了许久也连行业最先进水平也没赶上。在实际应用时,绝大部分有矩阵计算需求的人应该都会调用其他的科学计算框架。那么,我们为什么还要在这里研究矩阵乘法的优化呢?

从功利的角度回答,研究矩阵乘法优化的过程中,我们对于如何利用计算机的体系结构特点来进行优化进行了深入的学习和实践,这种经验在其他需要优化的场景中可能是有用的。另外,如果有一天我们需要在一个比较新的计算硬件上用到矩阵乘法,而常用的科学计算库还没有对该硬件进行针对性优化的话,我们或许能够亲自动手优化它。

从非功利的角度回答,just for fun!


  1. https://en.wikipedia.org/wiki/Row-_and_column-major_order↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值