GEMM 优化报告
实验任务
实现矩阵乘法 C = A ∗ B C=A*B C=A∗B,其中, A A A, B B B, C C C 是 N ∗ N N*N N∗N 的单精度稠密矩阵。本实验中矩阵均为column major。
实验环境
华为鲲鹏920:aarch64架构,64核CPU,CPU最高工作频率2600MHz。
L1d cache:64KB
L1i cache:64KB
L2 cache:512KB
L3 cache:32768KB
Page size:65536Byte
测试用例
测试用例中我们选取的矩阵规模为 n ∈ { 32 ∗ k ± 1 , 32 ∣ 1 ≤ k ≤ 32 } n \in \{32*k\pm1, 32|1\le k\le 32\} n∈{ 32∗k±1,32∣1≤k≤32} 。
优化流程
Step 0: naive
最简单粗暴的算法就是先按行遍历再按列遍历,分别计算 C i j C_{ij} Cij。在编译过程中,我们设置编译器不做任何优化。
void square_gemm (int n, float* A, float* B, float* C)
{
/* For each row i of A */
for (int i = 0; i < n; ++i)
/* For each column j of B */
for (int j = 0; j < n; ++j)
{
/* Compute C(i,j) */
float cij = C[i+j*n];
for( int k = 0; k < n; k++ )
cij += A[i+k*n] * B[k+j*n];
C[i+j*n] = cij;
}
}
该程序的性能如下图所示,其平均性能为0.33Gflops/s。
Step 1: 加入编译器优化
在Step 0的基础上加上了O3优化,以及-fomit-frame-pointer -march=armv8-a -ffast-math -mtune=tsv110编译选项。在编译器对代码进行自动优化后,程序的性能有了明显提升,如下图所示,平均浮点运算速度为2.47Gflops/s。但是程序的性能不太稳定,尤其是在矩阵规模是32的倍数的时候,性能反而下降明显。
Step 2: 利用neon intrinsic
在ARM-v8中有32个128位定长寄存器,每个寄存器可以存4个单精度浮点数,支持SIMD向量化操作。利用这一特性,我们可以四个四个地计算矩阵 C C C中的元素。
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
void solution_1 (int n, float* a, float* b, float* c){
int i, j;
for (j = 0; j < n; j++){
for (i = 0; i < ((n) & (~3)); i+=4){
float32x4_t buf = vld1q_f32(&C(i, j));
for (int k = 0; k < n; k++){
float32x4_t va = vld1q_f32(&A(i, k));
register float vb = B(k, j);
buf = vmlaq_n_f32(buf, va, vb);
}
vst1q_f32(&C(i, j), buf);
}
for (; i < n; i++){//deal with boundaries
register float temp = C(i, j);
for (int k = 0; k < n; k++){
temp += A(i, k) * B(k, j);
}
C(i, j) = temp;
}
}
}
加入SIMD向量化操作之后,程序的性能如下图所示,平均浮点运算速度达到3.68Gflops/s。
Step 3: 对矩阵B同时4列访问
在Step2中,矩阵 B B B中的每个元素在被load后只被使用了一次,为了提高矩阵B中元素的使用率,我们可以每次load矩阵 B B B中相邻4列的元素,进而通过对矩阵 A A A中的 4 × k 4\times k 4×k 的子矩阵和矩阵 B B B中的 k × 4 k\times 4 k×4 的子矩阵进行相乘,得到矩阵 C C C中的大小为 4 × 4 4\times 4 4×4的子矩阵。
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
//computing (4xk)x(kx4) dot product
void add_dot_4x4 (int n, int k, float* a, float* b, float* c){
float *b_ptr_0, *b_ptr_1, *b_ptr_2, *b_ptr_3;
b_ptr_0 = &B(0, 0);
b_ptr_1 = &B(0, 1);
b_ptr_2 = &B(0, 2);
b_ptr_3 = &B(0, 3);
float32x4_t c_sum_0 = {
0};
float32x4_t c_sum_1 = {
0};
float32x4_t c_sum_2 = {
0};
float32x4_t c_sum_3 = {
0};
register float b_reg_0, b_reg_1, b_reg_2, b_reg_3;
for (int p = 0; p < k; p++){
float32x4_t a_reg = vld1q_f32(&A(0, p));
b_reg_0 = *(b_ptr_0++);
b_reg_1 = *(b_ptr_1++);
b_reg_2 = *(b_ptr_2++);
b_reg_3 = *(b_ptr_3++);
c_sum_0 = vmlaq_n_f32(c_sum_0, a_reg, b_reg_0);
c_sum_1 = vmlaq_n_f32(c_sum_1, a_reg, b_reg_1);
c_sum_2 = vmlaq_n_f32(c_sum_2, a_reg, b_reg_2);
c_sum_3 = vmlaq_n_f32(c_sum_3, a_reg, b_reg_3);
}
float *c_ptr = 0;
c_ptr = &C(0, 0);
float32x4_t c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_0);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0