使用SIMD NEON对矩阵乘法优化计算

使用C、C++完成矩阵相乘操作:

两个矩阵A(nxk),B(kxm)相乘,得到的结果C的维度是(nxm),必须满足A的第二维等于B的第一维,最好在相乘操作之前进行校验。

void matrix_mul(uint8_t[][]& A,uint8_t[][]& B,uint8_t[][]& C,int n,int m,int k){

    /*

    A B 为输入矩阵,其中A的维度为(nxk),B的维度为 (kxm)

    C   为输出矩阵,其中C的维度为 (nxm)

    */



    for(int i = 0;i<n;i++){

        for(int j = 0; j < m;j++){

            for(int t = 0;t<k;t++){

                C[i][j] += A[i][t] * B[t][j];

            }

        }

    }

}

2、如果二维矩阵的存储格式为一维数组,例如图片像素值的BGR、YUV存储形式,都是将矩阵按照行的形式进行存储,即第二行的数据拼接在第一行后面,依次类推,假如矩阵A为

1, 2

3, 4

5, 6

则存储为 [1,2,3,4,5,6]。假如二维矩阵A的维度为 nxm,它的存储为一维数组 a ,则它的第 i 行第 j 列的值为A[i][j] = a[ i x m+j ] .

void matrix_mul(uint8_t* A,uint8_t* B,uint8_t* C,int n,int m,int k){

    /*

    A B 为输入矩阵,其中A的维度为(nxk),B的维度为 (kxm)

    C   为输出矩阵,其中C的维度为 (nxm)

    */

   for(int i = 0;i<n;i++){

        for(int j = 0; j < m;j++){

            for(int t = 0;t<k;t++){

                //C[i][j] += A[i][t] * B[t][j]; 根据这行代码和索引公式,可以替换为

                C[i*m+j] = A[i*k+t] * B[t*m+j];

            }

        }

    }

}

使用NEON对这个矩阵相乘进行计算。将c++代码进行并行计算,在ARM上提升性能。

首先我们使用4x4的矩阵为例:

对于C[i][j] += A[i][t] * B[t][j]这个代码,他就是遍历A矩阵的行和列,以及B矩阵的行和列的所有值,一个一个地相乘相加,再根据矩阵乘法的规则来得到结果,而NEON可以‘四个四个’地操作,亦可以‘八个八个’地操作,甚至可以‘十六个十六个’地操作,至于是四个还是八个、十六个就看你怎么操作了。

首先我们使用NEON读取A的第一行和B的所有行

float32x4_t A0 =  vld1q_f32(A+0*k);

float32x4_t B0 =  vld1q_f32(B+0*m);

float32x4_t B1 =  vld1q_f32(B+1*m);

float32x4_t B2 =  vld1q_f32(B+2*m);

float32x4_t B3 =  vld1q_f32(B+3*m);

解释一下,这里的A,B是以一维数组进行存储的,其中A(nxk),B(kxm) 结果是(nxm)。如同2中所示,vld1q_f32()是使用NEON的指令读取A的数值,从float32x4_t 的输出结果可以看出,我们一次性读取出了4个值,也正好是A的一行的值。这里用vld1q_f32(),而不是vld1_f32,代表使用的是128为的寄存器()。所以:

A0 = 【A(0,0),A(0,1),A(0,2),A(0,3)】

B0 = 【B(0,0),B(0,1),B(0,2),B(0,3)】

B0 = 【B(1,0),B(1,1),B(1,2),B(1,3)】

B0 = 【B(2,0),B(2,1),B(2,2),B(2,3)】

B0 = 【B(3,0),B(3,1),B(3,2),B(3,3)】

接下来,继续读取A中的第二行,然后继续和B的所有行相乘:

float32x4_t A1 =  vld1q_f32(A+1*k);

float32x4_t B0 =  vld1q_f32(B+0*m);

float32x4_t B1 =  vld1q_f32(B+1*m);

float32x4_t B2 =  vld1q_f32(B+2*m);

float32x4_t B3 =  vld1q_f32(B+3*m);

得到了结果矩阵的第二行的结果C1,继续保存到resn*m)一维数组中:

vst1q_f32(res+4, C1);

最后我们将代码整合一下:

void matrix_mul(float32_t* A,float32_t* B,float32_t* C,uint8_t n,uint8_t m,int t){  
    int A_idx;
    int B_idx;
    int C_idx;
    
    // these are the columns of a 4x4 sub matrix of A
    float32x4_t A0;
    float32x4_t A1;
    float32x4_t A2;
    float32x4_t A3;
    
    // these are the columns of a 4x4 sub matrix of B
    float32x4_t B0;
    float32x4_t B1;
    float32x4_t B2;
    float32x4_t B3;
    
    // these are the columns of a 4x4 sub matrix of C
    float32x4_t C0;
    float32x4_t C1;
    float32x4_t C2;
    float32x4_t C3;
    
    C0=vmovq_n_f32(0);
    C1=vmovq_n_f32(0);
    C2=vmovq_n_f32(0); 
    C3=vmovq_n_f32(0);
    
    A0 = vld1q_f32(A+0*4);             
    B0 = vld1q_f32(B+0*m+b_idx);
    B1 = vld1q_f32(B+1*m+b_idx);
    B2 = vld1q_f32(B+2*m+b_idx);
    B3 = vld1q_f32(B+3*m+b_idx);

    C0=vmlaq_lane_f32(C0,B0, vget_low_f32(A0), 0);
    C0=vmlaq_lane_f32(C0,B1, vget_low_f32(A0), 1);
    C0=vmlaq_lane_f32(C0,B2, vget_high_f32(A0), 0);
    C0=vmlaq_lane_f32(C0,B3, vget_high_f32(A0), 1);
                

    A1 = vld1q_f32(A+1*4);
    C1=vmlaq_lane_f32(C1,B0, vget_low_f32(A1), 0);
    C1=vmlaq_lane_f32(C1,B1, vget_low_f32(A1), 1);
    C1=vmlaq_lane_f32(C1,B2, vget_high_f32(A1), 0);
    C1=vmlaq_lane_f32(C1,B3, vget_high_f32(A1), 1);

    A2 = vld1q_f32(A+2*4);
    C2=vmlaq_lane_f32(C2,B0, vget_low_f32(A2), 0);
    C2=vmlaq_lane_f32(C2,B1, vget_low_f32(A2), 1);
    C2=vmlaq_lane_f32(C2,B2, vget_high_f32(A2), 0);
    C2=vmlaq_lane_f32(C2,B3, vget_high_f32(A2), 1);

    A3 = vld1q_f32(A+3*4);
    C3=vmlaq_lane_f32(C3,B0, vget_low_f32(A3), 0);
    C3=vmlaq_lane_f32(C3,B1, vget_low_f32(A3), 1);
    C3=vmlaq_lane_f32(C3,B2, vget_high_f32(A3), 0);
    C3=vmlaq_lane_f32(C3,B3, vget_high_f32(A3), 1);
     
            
    vst1q_f32(C+0*4, C0);
    vst1q_f32(C+1*4, C1);
    vst1q_f32(C+2*4, C2);
    vst1q_f32(C+3*4, C3);
}

int main(){

    float32_t matrixA[16] = {1, 2, 3, 4, 5, 6, 7, 8,
        9, 10, 11, 12, 13, 14, 15, 16};

    float32_t matrixB[16] = {
        1, 2, 3, 4, 5, 6, 7, 8,
        9, 10, 11, 12, 13, 14, 15, 16
    };

    float32_t res[16] = {0};
    for(int i = 0;i<16;i++){
        cout << res[i] << " ";
        
    }
    cout << endl;
    matrix_mul(matrixA,matrixB,res,8,8,8);
    for(int i = 0;i<64;i++){
        cout << res[i] << " ";
        
    }
    
    // system("pause");
    return 0;
}

  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值