https://github.com/tpoisonooo/how-to-optimize-gemm/blob/master/aarch64/output_MMult_4x4_10.m解析

https://github.com/tpoisonooo/how-to-optimize-gemm/blob/master/aarch64/output_MMult_4x4_10.m 学习笔记


前言

看了半天,不好懂,记一下


零、MNK

mnk是举证乘法中的行列数
C=C+AB
C是mxn的矩阵,A是mxk的矩阵,B是kxn的矩阵

一、函数之前

#ifdef __ARM_NEON
#include <arm_neon.h>
#else
#error("arm neon not supported")
#endif

这段是进行一个检测,检测当前平台架构是否是arm架构,如果是arm架构的话,会自动生成__ARM_NEON
如果不是arm的话,报错

/* Block sizes */
#define mc 256
#define kc 128

设置分块大小

/* Create macros so that the matrices are stored in row-major order */

#define A(i, j) a[(i) * lda + (j)]
#define B(i, j) b[(i) * ldb + (j)]
#define C(i, j) c[(i) * ldc + (j)]

#define min(i, j) ((i) < (j) ? (i) : (j))

定义major row(行优先存储),因为lda()与i相乘,所以是行优先,如果与j相乘则为列优先。

二、MY_MMult、InnerKernel以及AddDot4x4

一共是做了三层函数
最外层是MY_MMult、其次是InnerKernel,最内层是AddDot4x4

最外层矩阵分块,中间层按照4x4的大小来进行矩阵相乘,最内层是矩阵乘法的实现(4x4)

void MY_MMult(int m, int n, int k, float *a, int lda, float *b, int ldb,
              float *c, int ldc) {
  int i, p, pb, ib;
  for (p = 0; p < k; p += kc) {
    pb = min(k - p, kc);
    for (i = 0; i < m; i += mc) {
      ib = min(m - i, mc);
      InnerKernel(ib, n, pb, &A(i, p), lda, &B(p, 0), ldb, &C(i, 0), ldc);
    }
  }
}

注:pb=min(k-p,kc),kc是规则,按照kc的大小来对矩阵A的列和矩阵B的行进行分块,但是实际过程中,会出现到最后分不到kc大小的块,于是需要一个min操作保证列数始终是正确的
ib=min(m-i,mc)同理,不过是对矩阵A的行进行分块。

void InnerKernel(int m, int n, int k, float *a, int lda, float *b, int ldb,
                 float *c, int ldc) {
  int i, j;

  for (j = 0; j < n; j += 4) {   /* Loop over the columns of C, unrolled by 4 */
    for (i = 0; i < m; i += 4) { /* Loop over the rows of C */
      /* Update C( i,j ), C( i,j+1 ), C( i,j+2 ), and C( i,j+3 ) in
         one routine (four inner products) */

      AddDot4x4(k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j), ldc);
    }
  }
}

简单的两个循环对已分块的矩阵进行4x4的划分


void AddDot4x4(int k, float *a, int lda, float *b, int ldb, float *c, int ldc) {
  /* So, this routine computes a 4x4 block of matrix A

           C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 ).
           C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 ).
           C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 ).
           C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 ).

     Notice that this routine is called with c = C( i, j ) in the
     previous routine, so these are actually the elements

           C( i  , j ), C( i  , j+1 ), C( i  , j+2 ), C( i  , j+3 )
           C( i+1, j ), C( i+1, j+1 ), C( i+1, j+2 ), C( i+1, j+3 )
           C( i+2, j ), C( i+2, j+1 ), C( i+2, j+2 ), C( i+2, j+3 )
           C( i+3, j ), C( i+3, j+1 ), C( i+3, j+2 ), C( i+3, j+3 )

     in the original matrix C

     In this version, we use registers for elements in the current row
     of B as well */

  float
      /* Point to the current elements in the four rows of A */
      *a_0p_pntr,
      *a_1p_pntr, *a_2p_pntr, *a_3p_pntr;

  a_0p_pntr = &A(0, 0);
  a_1p_pntr = &A(1, 0);
  a_2p_pntr = &A(2, 0);
  a_3p_pntr = &A(3, 0);

  float32x4_t c_0p_sum = { 0 };
  float32x4_t c_1p_sum = { 0 };
  float32x4_t c_2p_sum = { 0 };
  float32x4_t c_3p_sum = { 0 };

  register float a_0p_reg, a_1p_reg, a_2p_reg, a_3p_reg;

  for (int p = 0; p < k; ++p) {
    float32x4_t b_reg = vld1q_f32(&B(p, 0));

    a_0p_reg = *a_0p_pntr++;
    a_1p_reg = *a_1p_pntr++;
    a_2p_reg = *a_2p_pntr++;
    a_3p_reg = *a_3p_pntr++;

    c_0p_sum = vmlaq_n_f32(c_0p_sum, b_reg, a_0p_reg);
    c_1p_sum = vmlaq_n_f32(c_1p_sum, b_reg, a_1p_reg);
    c_2p_sum = vmlaq_n_f32(c_2p_sum, b_reg, a_2p_reg);
    c_3p_sum = vmlaq_n_f32(c_3p_sum, b_reg, a_3p_reg);
  }

  float *c_pntr = 0;
  c_pntr = &C(0, 0);
  float32x4_t c_reg = vld1q_f32(c_pntr);
  c_reg = vaddq_f32(c_reg, c_0p_sum);
  vst1q_f32(c_pntr, c_reg);

  c_pntr = &C(1, 0);
  c_reg = vld1q_f32(c_pntr);
  c_reg = vaddq_f32(c_reg, c_1p_sum);
  vst1q_f32(c_pntr, c_reg);

  c_pntr = &C(2, 0);
  c_reg = vld1q_f32(c_pntr);
  c_reg = vaddq_f32(c_reg, c_2p_sum);
  vst1q_f32(c_pntr, c_reg);

  c_pntr = &C(3, 0);
  c_reg = vld1q_f32(c_pntr);
  c_reg = vaddq_f32(c_reg, c_3p_sum);
  vst1q_f32(c_pntr, c_reg);
}

这其中使用了几个neon指令

首先从float指针开始,定义了四个float指针类型,每个指针分别指向A矩阵的四行
定义了四个float32x4_t(表示一个包含四个单精度浮点数的向量,SIMD思想)类型的变量,正好也对应了4x4的矩阵运算
使用register 关键字定义了四个float型变量在cpu寄存器中

进入第一个循环,对A的列B的行进行循环,每次移动一个长度,循环计数是p
使用vld1q_f32指令将B的第p行写入到float32x4_t类型的neon寄存器中

每次循环指针后移地将A矩阵的每一行的单个元素放到register float类型的变量中
用四个float32x4_t的sum变量来接收利用vmlaq_n_f32指令的累乘结果,内层是一个saxpy修正

saxpy修正(单行代码举例):

c_0p_sum = vmlaq_n_f32(c_0p_sum, b_reg, a_0p_reg);

这其中的b_reg和c_0p_sum是一个float32x4_t类型,是一个包含4个单精度浮点数的向量,而后面的a_0p_reg是一个float型标量,即saxpy型修正(y=ax+y,a为标量,标量乘以向量),并将结果存储到c_0p_sum中。

后续的四个部分代码就是将C的四行进行更新,将刚刚算出来的sum值与原本的C对应的位置的值相加再更新:指针指向C的相关位置,取出此位置的值于c_reg,再求c_reg和sum的和,最后更新C。


总结

此止

  • 13
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值