NEON实现矩阵乘法加速

1、使用C、C++完成矩阵相乘操作:

两个矩阵A(nxk),B(kxm)相乘,得到的结果C的维度是(nxm),必须满足A的第二维等于B的第一维,最好在相乘操作之前进行校验。

  1. 假设两个矩阵是二维数组进行存储,则矩阵相乘的代码如下
void matrix_mul(uint8_t[][]& A,uint8_t[][]& B,uint8_t[][]& C,int n,int m,int k){
    /*
    A B 为输入矩阵,其中A的维度为(nxk),B的维度为 (kxm)
    C   为输出矩阵,其中C的维度为 (nxm)
    */

    for(int i = 0;i<n;i++){
        for(int j = 0; j < m;j++){
            for(int t = 0;t<k;t++){
                C[i][j] += A[i][t] * B[t][j];
            }
        }
    }
}
  1. 如果二维矩阵的存储格式为一维数组,例如图片像素值的BGR、YUV存储形式,都是将矩阵按照行的形式进行存储,即第二行的数据拼接在第一行后面,依次类推,假如矩阵A为
    1, 2
    3, 4
    5, 6
    则存储为 [1,2,3,4,5,6]。假如二维矩阵A的维度为 nxm,它的存储为一维数组 a ,则它的第 i 行第 j 列的值为A[i][j] = a[ i x m+j ] .
void matrix_mul(uint8_t* A,uint8_t* B,uint8_t* C,int n,int m,int k){
    /*
    A B 为输入矩阵,其中A的维度为(nxk),B的维度为 (kxm)
    C   为输出矩阵,其中C的维度为 (nxm)
    */
   for(int i = 0;i<n;i++){
        for(int j = 0; j < m;j++){
            for(int t = 0;t<k;t++){
                //C[i][j] += A[i][t] * B[t][j]; 根据这行代码和索引公式,可以替换为
                C[i*m+j] = A[i*k+t] * B[t*m+j];
            }
        }
    }
}

当然首先主要将C 初始化为 new uint8_t[n*m]啦。

3.重磅!!!使用NEON对这个矩阵相乘进行计算。最近笔者在学习NEON,将c++代码进行并行计算,在ARM上提升性能。挖个坑,后续还会更新使用NEON加速图像旋转、插值、颜色空间转换等。
首先我们使用4x4的矩阵为例来进行讲解。

  • 首先想用NEON进行矩阵操作的小伙伴需要了解什么是NEON,NEON简单来说就是并行(多路加载数据),举个例子,刚刚使用c++写的这句代码:
C[i][j] += A[i][t] * B[t][j]
  • 他就是遍历A矩阵的行和列,以及B矩阵的行和列的所有值,一个一个地相乘相加,再根据矩阵乘法的规则来得到结果
  • NEON觉得,‘一个一个’太慢了,NEON说它可以‘四个四个’地操作,亦可以‘八个八个’地操作,甚至可以‘十六个十六个’地操作,可牛逼了,至于是四个还是八个、十六个就看你怎么操作了。

在这里插入图片描述比如矩阵AXB=C,上图列出了C的第一列的计算方法:

在这里插入图片描述我们可以放心啊,对于C的第一列,它的每一个元素在计算中,每一个红框内,A的值都是同一个,而B的值是同一列的不同的值,找到了这个规律我们就可以使用NEON来进行并行读取

首先我们使用NEON读取A的第一行和B的所有行

在这里插入代码片:
float32x4_t A0 =  vld1q_f32(A+0*k);
float32x4_t B0 =  vld1q_f32(B+0*m);
float32x4_t B1 =  vld1q_f32(B+1*m);
float32x4_t B2 =  vld1q_f32(B+2*m);
float32x4_t B3 =  vld1q_f32(B+3*m);

解释一下,这里的A,B是以一维数组进行存储的,其中A(nxk),B(kxm) 结果是(nxm)。
如同2中所示,vld1q_f32()是使用NEON的指令读取A的数值,从float32x4_t 的输出结果可以看出,我们一次性读取出了4个值,也正好是A的一行的值。这里用vld1q_f32(),而不是vld1_f32,代表使用的是128为的寄存器(不理解的小伙伴需要先去看看NEON的一些基本操作和知识噢)。
所以`:

A0 = 【A(0,0),A(0,1),A(0,2),A(0,3)】
B0 = 【B(0,0),B(0,1),B(0,2),B(0,3)】
B0 = 【B(1,0),B(1,1),B(1,2),B(1,3)】
B0 = 【B(2,0),B(2,1),B(2,2),B(2,3)】
B0 = 【B(3,0),B(3,1),B(3,2),B(3,3)

然后我们要结合上面给出的矩阵C的计算公式来进行计算
我们使用B0、B1、B2、B3分别与A0的第一、二、三、四个值进行计算,就可以得到:

B0*A0[0] = 【B(0,0)*A(0,0),B(0,1)*A(0,0),B(0,2)*A(0,0),B(0,3)*A(0,0)】
B1*A0[1] = 【B(1,0)*A(0,1),B(1,1)*A(0,1),B(1,2)*A(0,1),B(1,3)*A(0,1)】
B2*A0[2] = 【B(2,0)*A(0,2),B(2,1)*A(0,2),B(2,2)*A(0,2),B(2,3)*A(0,2)】
B3*A0[3] = 【B(3,0)*A(0,3),B(3,1)*A(0,3),B(3,2)*A(0,3),B(3,3)*A(0,3)

然后我们分别把这四列记录为C0,并相加
所以C0 = 【
B(0,0)*A(0,0)+B(1,0)*A(0,1)+B(2,0)*A(0,2)B(3,0)*A(0,3),
B(0,1)*A(0,0)+B(1,1)*A(0,1)+B(2,1)*A(0,2)+B(3,1)*A(0,3),
B(0,2)*A(0,0)+B(1,2)*A(0,1)+B(2,2)*A(0,2)+B(3,2)*A(0,3),
B(0,3)*A(0,0)+B(1,3)*A(0,1)+B(2,3)*A(0,2)+B(3,3)A(0,3)

有没有发现,C的四个值,正好就是矩阵AXB的结果的前四个值呢,哈哈哈非常神奇吧。
然后我们需要把这四个结果存储到结果矩阵中,我们使用**vst1q_f32()**函数将结果保存到res(n
m)一维数组中:

vst1q_f32(res+0, C0);

OKOK,那么接着我们读取A的第二行,然后继续和B的所有行相乘:

float32x4_t A1 =  vld1q_f32(A+1*k);
float32x4_t B0 =  vld1q_f32(B+0*m);
float32x4_t B1 =  vld1q_f32(B+1*m);
float32x4_t B2 =  vld1q_f32(B+2*m);
float32x4_t B3 =  vld1q_f32(B+3*m);

得到了结果矩阵的第二行的结果C1,继续保存到res(n*m)一维数组中:

vst1q_f32(res+4, C1);

同理是C2和C3,这样整个结果就得到了,不知道解释的清不清楚呢。
最后我们再把代码给整合一下:

void matrix_mul(float32_t* A,float32_t* B,float32_t* C,uint8_t n,uint8_t m,int t){  
    int A_idx;
    int B_idx;
    int C_idx;
    
    // these are the columns of a 4x4 sub matrix of A
    float32x4_t A0;
    float32x4_t A1;
    float32x4_t A2;
    float32x4_t A3;
    
    // these are the columns of a 4x4 sub matrix of B
    float32x4_t B0;
    float32x4_t B1;
    float32x4_t B2;
    float32x4_t B3;
    
    // these are the columns of a 4x4 sub matrix of C
    float32x4_t C0;
    float32x4_t C1;
    float32x4_t C2;
    float32x4_t C3;
    
    C0=vmovq_n_f32(0);
    C1=vmovq_n_f32(0);
    C2=vmovq_n_f32(0); 
    C3=vmovq_n_f32(0);
    
    A0 = vld1q_f32(A+0*4);             
    B0 = vld1q_f32(B+0*m+b_idx);
    B1 = vld1q_f32(B+1*m+b_idx);
    B2 = vld1q_f32(B+2*m+b_idx);
    B3 = vld1q_f32(B+3*m+b_idx);

    C0=vmlaq_lane_f32(C0,B0, vget_low_f32(A0), 0);
    C0=vmlaq_lane_f32(C0,B1, vget_low_f32(A0), 1);
    C0=vmlaq_lane_f32(C0,B2, vget_high_f32(A0), 0);
    C0=vmlaq_lane_f32(C0,B3, vget_high_f32(A0), 1);
                

    A1 = vld1q_f32(A+1*4);
    C1=vmlaq_lane_f32(C1,B0, vget_low_f32(A1), 0);
    C1=vmlaq_lane_f32(C1,B1, vget_low_f32(A1), 1);
    C1=vmlaq_lane_f32(C1,B2, vget_high_f32(A1), 0);
    C1=vmlaq_lane_f32(C1,B3, vget_high_f32(A1), 1);

    A2 = vld1q_f32(A+2*4);
    C2=vmlaq_lane_f32(C2,B0, vget_low_f32(A2), 0);
    C2=vmlaq_lane_f32(C2,B1, vget_low_f32(A2), 1);
    C2=vmlaq_lane_f32(C2,B2, vget_high_f32(A2), 0);
    C2=vmlaq_lane_f32(C2,B3, vget_high_f32(A2), 1);

    A3 = vld1q_f32(A+3*4);
    C3=vmlaq_lane_f32(C3,B0, vget_low_f32(A3), 0);
    C3=vmlaq_lane_f32(C3,B1, vget_low_f32(A3), 1);
    C3=vmlaq_lane_f32(C3,B2, vget_high_f32(A3), 0);
    C3=vmlaq_lane_f32(C3,B3, vget_high_f32(A3), 1);
     
            
    vst1q_f32(C+0*4, C0);
    vst1q_f32(C+1*4, C1);
    vst1q_f32(C+2*4, C2);
    vst1q_f32(C+3*4, C3);
}

int main(){

    float32_t matrixA[16] = {1, 2, 3, 4, 5, 6, 7, 8,
        9, 10, 11, 12, 13, 14, 15, 16};

    float32_t matrixB[16] = {
        1, 2, 3, 4, 5, 6, 7, 8,
        9, 10, 11, 12, 13, 14, 15, 16
    };

    float32_t res[16] = {0};
    for(int i = 0;i<16;i++){
        cout << res[i] << " ";
        
    }
    cout << endl;
    matrix_mul(matrixA,matrixB,res,8,8,8);
    for(int i = 0;i<64;i++){
        cout << res[i] << " ";
        
    }
    
    // system("pause");
    return 0;
}
  1. 再解释一下vst1q_f32(C+0*4,C0);这个函数吧,就是说把C0的四个数字,按顺序填充到结果数组C里面,因为本来就是按照c00,c01,c02,c03,c10…的这个顺序存储的,所以没有问题。
  2. 然后C0=vmovq_n_f32(0)是一个初始化的函数
  3. vld1q_f32是一个读取数据的函数;
  4. **vmlaq_lane_f32(C1,B0,vget_low_f32(A1),0)**中的vmlaq_lane_f32是一个乘加函数,vmlaq_lane_f32(a,b,c)=a+bxc,还记得上面提到的C是要加起来的吗?
  5. 其次是解释一下**vget_low_f32和vget_high_f32()**函数:还记得我们需要我们使用B0、B1、B2、B3分别与A0的第一、二、三、四个值进行计算,然后是B0、B1、B2、B3和A0的第一、二、三、四个值和进行计算,得到结果的第二行,记得吧,但是我们使用vld1q_f32得到A0后,并不能方便的得到A0的第一个值A0[0],第二个值A0[1],毕竟这个A0不是数组,它是连续读进NEON寄存器的一串数字,所以我们需要想办法得到A0[0]、A0[1]、A0[2]、A0[3],当然我们可以记住一个结论:
假如寄存器TT中是四个数字 【a,b,c,d】
那么vget_low_f32(TT)= 【a,b】,vget_high_f32(TT) = 【c,d】
vmlaq_lane_f32(C1,B0,vget_low_f32(A1),0) 中的vget_low_f32(A1),0
就相当于 a,vget_low_f32(A1),1就相当于b
同理如果你要找到 c或d,就把vget_low_f32替换从vget_high_f32

ok 44与44的矩阵乘法我们已经写完了,接下来就要进阶到nk与km的矩阵了,其实就是分割成若干个44与44的矩阵相乘,再求和。

我们举一个8*8的矩阵为例子:

在这里插入图片描述C的第一个值就是A的第一行乘上B的第一列,假如我们把A的前四行看成是两个4*4的矩阵(第一列到第四列,第五列到第七列),而B的前四列与(第一到第四行),如下图所示
在这里插入图片描述可以发现,所以C(0,0)= A(0,0)*B(0,0)+A(0,1)*B(1,0)+A(0,2)*B(2,0)+A(0,3)*B(3,0)+A(0,4)*B(4,0)+A(0,5)*B(5,0)+A(0,6)*B(6,0)+A(0,7)*B(7,0)=

A0(0,0)*B0(0,0)+A0(0,1)*B0(1,0)+A0(0,2)*B0(2,0)+A0(0,3)*B0(3,0)
+
A1(0,0)*B1(0,0)+A1(0,1)*B1(1,0)+A1(0,2)*B1(2,0)+A1(0,3)*B1(3,0)

所以说,我们只要把大矩阵按照每4列划分一个44小矩阵,然后把B划分每4行划分为一个44小矩阵,然后不断加起来,得到的就是结果,举个例子,小矩阵A0和B0的结果是C0,然后再把A1和B1的结果加到C0上,当然这是88的矩阵,如果是1212甚至其他,也是同样的到理解,所以就列出完整的代码:

#include <iostream>

#include <arm_neon.h>

using namespace std;

/*
    A (mxt)  B(txn)  AxB=C(mxn)

    
*/

//vmlaq_lane_f32

// float32x4_t vfmaq_laneq_f32(float32x4_t C0,float32x4_t B0,float32x4_t A0,const int index){
//     return vmlaq_lane_f32(C0, B0, vdupq_n_f32(vgetq_lane_f32(A0,index)));
// }
// float32x4_t vfmaq_laneq_f32(float32x4_t C0,float32x4_t B0,float32_t t){
//     return vmlaq_f32(C0, B0, t);
// }
void matrix_mul(float32_t* A,float32_t* B,float32_t* C,uint8_t m,uint8_t t,uint8_t n){
    
    int A_idx;
    int B_idx;
    int C_idx;

    // these are the columns of a 4x4 sub matrix of A
    float32x4_t A0;
    float32x4_t A1;
    float32x4_t A2;
    float32x4_t A3;
    
    // these are the columns of a 4x4 sub matrix of B
    float32x4_t B0;
    float32x4_t B1;
    float32x4_t B2;
    float32x4_t B3;
    
    // these are the columns of a 4x4 sub matrix of C
    float32x4_t C0;
    float32x4_t C1;
    float32x4_t C2;
    float32x4_t C3;

    for(int i = 0;i<m;i+=4){
        
        for(int j = 0;j<n;j+=4){
            C0=vmovq_n_f32(0);
            C1=vmovq_n_f32(0);
            C2=vmovq_n_f32(0); 
            C3=vmovq_n_f32(0);
            for(int k = 0;k<t;k+=4){


                int a_idx = i*t+k;
                int b_idx = k*n+j;

                A0 = vld1q_f32(A+0*t+a_idx);
                B0 = vld1q_f32(B+0*n+b_idx);
                B1 = vld1q_f32(B+1*n+b_idx);
                B2 = vld1q_f32(B+2*n+b_idx);
                B3 = vld1q_f32(B+3*n+b_idx);

                

                //float32x4_t res = vmlaq_lane_f32(q0, q1, vget_low_f32(q2), 0);
                C0 = vmlaq_lane_f32(C0,B0,vget_low_f32(A0),0);
                C0 = vmlaq_lane_f32(C0,B1,vget_low_f32(A0),1);
                C0 = vmlaq_lane_f32(C0,B2,vget_high_f32(A0),0);
                C0 = vmlaq_lane_f32(C0,B3,vget_high_f32(A0),1);


                A1 = vld1q_f32(A+1*t+a_idx);
                C1 = vmlaq_lane_f32(C1,B0,vget_low_f32(A1),0);
                C1 = vmlaq_lane_f32(C1,B1,vget_low_f32(A1),1);
                C1 = vmlaq_lane_f32(C1,B2,vget_high_f32(A1),0);
                C1 = vmlaq_lane_f32(C1,B3,vget_high_f32(A1),1);

                A2 = vld1q_f32(A+2*t+a_idx);
                C2 = vmlaq_lane_f32(C2,B0,vget_low_f32(A2),0);
                C2 = vmlaq_lane_f32(C2,B1,vget_low_f32(A2),1);
                C2 = vmlaq_lane_f32(C2,B2,vget_high_f32(A2),0);
                C2 = vmlaq_lane_f32(C2,B3,vget_high_f32(A2),1);


                A3 = vld1q_f32(A+3*t+a_idx);
                C3 = vmlaq_lane_f32(C3,B0,vget_low_f32(A3),0);
                C3 = vmlaq_lane_f32(C3,B1,vget_low_f32(A3),1);
                C3 = vmlaq_lane_f32(C3,B2,vget_high_f32(A3),0);
                C3 = vmlaq_lane_f32(C3,B3,vget_high_f32(A3),1);

               }

            int c_idx = i*n+j;
            vst1q_f32(C+0*n+c_idx,C0);
            vst1q_f32(C+1*n+c_idx,C1);
            vst1q_f32(C+2*n+c_idx,C2);
            vst1q_f32(C+3*n+c_idx,C3);
        }
    }
}

int main()
    {

    

    float32_t matrixA[144] = {0};

    float32_t matrixB[12] = {0};

    for(int i=1;i<=144;i++){
        matrixA[i-1] = i;
        
    }
    for(int i=1;i<=12;i++){
        matrixB[i-1] = i;
    }
    

    
    for(int i = 0;i<144;i++){
        cout << matrixA[i] << ",";
        
    }
    cout << endl;
    for(int i = 0;i<12;i++){
        cout << matrixB[i] << ",";
        
    }
    cout << endl;
    
    float32_t res[12] = {0};

    matrix_mul(matrixA,matrixB,res,12,12,1);
    for(int i = 0;i<12;i++){
        cout << res[i] << " "; 
        
    }
    cout << endl;
    
    // system("pause");
    return 0;
    }



具体的流程体现在:

for(int i = 0;i<n;i+=4){
        for(int j = 0;j<m;j+=4){
            for(int k = 0;k<t;k+=4){
               int a_idx = i*n+k; // 还记得这个是怎么来的吗
               int b_idx = m*k+j; // c[i][j] += a[i][k]*b[k][j]
}
        C_idx = i*n+j;
        vst1q_f32(C+0*n+i*n+j, C0); // 最后的结果保存如下
        vst1q_f32(C+1*n+i*n+j, C1);
        vst1q_f32(C+2*n+i*n+j, C2);
        vst1q_f32(C+3*n+i*n+j, C3);

    }   
}

上面就是完整的矩阵相乘代码,但是有个缺陷,就是M、N、K必须是能够整除4的结果,不然就会出错,因为NEON只能最小是‘4个4个读取’,如果你的矩阵是7*7或者是别的什么的,还得考虑很复杂的边界条件。

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值