BLAS之GEMM和GEMV

BLAS简介

BLAS全称是Basic Linear Algebra Subprograms是规定了一套低级的执行常见线性代数操作的规范。其实现经常针对特殊的机器进行优化,比较著名的·BLAS库有ACML, ATLAS, MKL, OpenBLAS。许多常见的数值软件均采用兼容BLAS规范的实现库来进行线性代数计算,比如Matlab, Numpy, Mathematica`。

其中,Level 1 BLAS主要提供向量操作


Level 2 BLAS提供矩阵向量操作(gemv)

gemv


Level 3 BLAS则提供广义矩阵乘积操作(gemm)

gemm


GEMM在深度学习中是十分重要的,全连接层以及卷积层基本上都是通过GEMM来实现的,而网络中大约90%的运算都是在这两层中。而一个良好的GEMM的实现可以充分利用系统的多级存储结构程序执行的局部性来充分加速运算。

 

gemm接口解析

gemm的函数接口如下图所示,darknet中也采用了类似的接口设计。

sgemm
其中,A,B,C分别是MxK, KxN, MxN的矩阵,TRANSA, TRANSB, TRANSC表示是否使用对应矩阵的转置,ALPHA, BETA为对应的系数。LDA, LDB, LDC表示对应矩阵的leading dimension,即第一维度的大小。根据我的理解(结合darknet的源码),是因为在内存中是连续存放的,而这个leading dimension的量是用来定义元素的位置的,即add(A[i, j])=A+i*lda+j。其中sgemm中的s表示是单精度的运算,类似的,还有dgemm

gemm代码分析:


源码

 *
 *  -- Reference BLAS level3 routine (version 3.7.0) --
 *  -- Reference BLAS is a software package provided by Univ. of Tennessee,    --
 *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 *     December 2016
 *
 *     .. Scalar Arguments ..
       REAL ALPHA,BETA
       INTEGER K,LDA,LDB,LDC,M,N
       CHARACTER TRANSA,TRANSB
 *     ..
 *     .. Array Arguments ..
       REAL A(LDA,*),B(LDB,*),C(LDC,*)
 *     ..
 *
 *  =====================================================================
 *
 *     .. External Functions ..
       LOGICAL LSAME
       EXTERNAL lsame
 *     ..
 *     .. External Subroutines ..
       EXTERNAL xerbla
 *     ..
 *     .. Intrinsic Functions ..
       INTRINSIC max
 *     ..
 *     .. Local Scalars ..
       REAL TEMP
       INTEGER I,INFO,J,L,NCOLA,NROWA,NROWB
       LOGICAL NOTA,NOTB
 *     ..
 *     .. Parameters ..
       REAL ONE,ZERO
       parameter(one=1.0e+0,zero=0.0e+0)
 *     ..
 *
 *     Set  NOTA  and  NOTB  as  true if  A  and  B  respectively are not
 *     transposed and set  NROWA, NCOLA and  NROWB  as the number of rows
 *     and  columns of  A  and the  number of  rows  of  B  respectively.
 *
       nota = lsame(transa,'N')
       notb = lsame(transb,'N')
       IF (nota) THEN
           nrowa = m
           ncola = k
       ELSE
           nrowa = k
           ncola = m
       END IF
       IF (notb) THEN
           nrowb = k
       ELSE
           nrowb = n
       END IF
 *
 *     Test the input parameters.
 *
       info = 0
       IF ((.NOT.nota) .AND. (.NOT.lsame(transa,'C')) .AND.
      +    (.NOT.lsame(transa,'T'))) THEN
           info = 1
       ELSE IF ((.NOT.notb) .AND. (.NOT.lsame(transb,'C')) .AND.
      +         (.NOT.lsame(transb,'T'))) THEN
           info = 2
       ELSE IF (m.LT.0) THEN
           info = 3
       ELSE IF (n.LT.0) THEN
           info = 4
       ELSE IF (k.LT.0) THEN
           info = 5
       ELSE IF (lda.LT.max(1,nrowa)) THEN
           info = 8
       ELSE IF (ldb.LT.max(1,nrowb)) THEN
           info = 10
       ELSE IF (ldc.LT.max(1,m)) THEN
           info = 13
       END IF
       IF (info.NE.0) THEN
           CALL xerbla('SGEMM ',info)
           RETURN
       END IF
 *
 *     Quick return if possible.
 *
       IF ((m.EQ.0) .OR. (n.EQ.0) .OR.
      +    (((alpha.EQ.zero).OR. (k.EQ.0)).AND. (beta.EQ.one))) RETURN
 *
 *     And if  alpha.eq.zero.
 *
       IF (alpha.EQ.zero) THEN
           IF (beta.EQ.zero) THEN
               DO 20 j = 1,n
                   DO 10 i = 1,m
                       c(i,j) = zero
    10             CONTINUE
    20         CONTINUE
           ELSE
               DO 40 j = 1,n
                   DO 30 i = 1,m
                       c(i,j) = beta*c(i,j)
    30             CONTINUE
    40         CONTINUE
           END IF
           RETURN
       END IF
 *
 *     Start the operations.
 *
       IF (notb) THEN
           IF (nota) THEN
 *
 *           Form  C := alpha*A*B + beta*C.
 *
               DO 90 j = 1,n
                   IF (beta.EQ.zero) THEN
                       DO 50 i = 1,m
                           c(i,j) = zero
    50                 CONTINUE
                   ELSE IF (beta.NE.one) THEN
                       DO 60 i = 1,m
                           c(i,j) = beta*c(i,j)
    60                 CONTINUE
                   END IF
                   DO 80 l = 1,k
                       temp = alpha*b(l,j)
                       DO 70 i = 1,m
                           c(i,j) = c(i,j) + temp*a(i,l)
    70                 CONTINUE
    80             CONTINUE
    90         CONTINUE
           ELSE
 *
 *           Form  C := alpha*A**T*B + beta*C
 *
               DO 120 j = 1,n
                   DO 110 i = 1,m
                       temp = zero
                       DO 100 l = 1,k
                           temp = temp + a(l,i)*b(l,j)
   100                 CONTINUE
                       IF (beta.EQ.zero) THEN
                           c(i,j) = alpha*temp
                       ELSE
                           c(i,j) = alpha*temp + beta*c(i,j)
                       END IF
   110             CONTINUE
   120         CONTINUE
           END IF
       ELSE
           IF (nota) THEN
 *
 *           Form  C := alpha*A*B**T + beta*C
 *
               DO 170 j = 1,n
                   IF (beta.EQ.zero) THEN
                       DO 130 i = 1,m
                           c(i,j) = zero
   130                 CONTINUE
                   ELSE IF (beta.NE.one) THEN
                       DO 140 i = 1,m
                           c(i,j) = beta*c(i,j)
   140                 CONTINUE
                   END IF
                   DO 160 l = 1,k
                       temp = alpha*b(j,l)
                       DO 150 i = 1,m
                           c(i,j) = c(i,j) + temp*a(i,l)
   150                 CONTINUE
   160             CONTINUE
   170         CONTINUE
           ELSE
 *
 *           Form  C := alpha*A**T*B**T + beta*C
 *
               DO 200 j = 1,n
                   DO 190 i = 1,m
                       temp = zero
                       DO 180 l = 1,k
                           temp = temp + a(l,i)*b(j,l)
   180                 CONTINUE
                       IF (beta.EQ.zero) THEN
                           c(i,j) = alpha*temp
                       ELSE
                           c(i,j) = alpha*temp + beta*c(i,j)
                       END IF
   190             CONTINUE
   200         CONTINUE
           END IF
       END IF
 *
       RETURN
 *
 *     End of SGEMM .
 *

 SGEMV  performs one of the matrix-vector operations

    y := alpha*A*x + beta*y,   or   y := alpha*A**T*x + beta*y,

 where alpha and beta are scalars, x and y are vectors and A is an
 m by n matrix.

源码: 

 *
 *  -- Reference BLAS level2 routine (version 3.7.0) --
 *  -- Reference BLAS is a software package provided by Univ. of Tennessee,    --
 *  -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
 *     December 2016
 *
 *     .. Scalar Arguments ..
       REAL ALPHA,BETA
       INTEGER INCX,INCY,LDA,M,N
       CHARACTER TRANS
 *     ..
 *     .. Array Arguments ..
       REAL A(LDA,*),X(*),Y(*)
 *     ..
 *
 *  =====================================================================
 *
 *     .. Parameters ..
       REAL ONE,ZERO
       parameter(one=1.0e+0,zero=0.0e+0)
 *     ..
 *     .. Local Scalars ..
       REAL TEMP
       INTEGER I,INFO,IX,IY,J,JX,JY,KX,KY,LENX,LENY
 *     ..
 *     .. External Functions ..
       LOGICAL LSAME
       EXTERNAL lsame
 *     ..
 *     .. External Subroutines ..
       EXTERNAL xerbla
 *     ..
 *     .. Intrinsic Functions ..
       INTRINSIC max
 *     ..
 *
 *     Test the input parameters.
 *
       info = 0
       IF (.NOT.lsame(trans,'N') .AND. .NOT.lsame(trans,'T') .AND.
      +    .NOT.lsame(trans,'C')) THEN
           info = 1
       ELSE IF (m.LT.0) THEN
           info = 2
       ELSE IF (n.LT.0) THEN
           info = 3
       ELSE IF (lda.LT.max(1,m)) THEN
           info = 6
       ELSE IF (incx.EQ.0) THEN
           info = 8
       ELSE IF (incy.EQ.0) THEN
           info = 11
       END IF
       IF (info.NE.0) THEN
           CALL xerbla('SGEMV ',info)
           RETURN
       END IF
 *
 *     Quick return if possible.
 *
       IF ((m.EQ.0) .OR. (n.EQ.0) .OR.
      +    ((alpha.EQ.zero).AND. (beta.EQ.one))) RETURN
 *
 *     Set  LENX  and  LENY, the lengths of the vectors x and y, and set
 *     up the start points in  X  and  Y.
 *
       IF (lsame(trans,'N')) THEN
           lenx = n
           leny = m
       ELSE
           lenx = m
           leny = n
       END IF
       IF (incx.GT.0) THEN
           kx = 1
       ELSE
           kx = 1 - (lenx-1)*incx
       END IF
       IF (incy.GT.0) THEN
           ky = 1
       ELSE
           ky = 1 - (leny-1)*incy
       END IF
 *
 *     Start the operations. In this version the elements of A are
 *     accessed sequentially with one pass through A.
 *
 *     First form  y := beta*y.
 *
       IF (beta.NE.one) THEN
           IF (incy.EQ.1) THEN
               IF (beta.EQ.zero) THEN
                   DO 10 i = 1,leny
                       y(i) = zero
    10             CONTINUE
               ELSE
                   DO 20 i = 1,leny
                       y(i) = beta*y(i)
    20             CONTINUE
               END IF
           ELSE
               iy = ky
               IF (beta.EQ.zero) THEN
                   DO 30 i = 1,leny
                       y(iy) = zero
                       iy = iy + incy
    30             CONTINUE
               ELSE
                   DO 40 i = 1,leny
                       y(iy) = beta*y(iy)
                       iy = iy + incy
    40             CONTINUE
               END IF
           END IF
       END IF
       IF (alpha.EQ.zero) RETURN
       IF (lsame(trans,'N')) THEN
 *
 *        Form  y := alpha*A*x + y.
 *
           jx = kx
           IF (incy.EQ.1) THEN
               DO 60 j = 1,n
                   temp = alpha*x(jx)
                   DO 50 i = 1,m
                       y(i) = y(i) + temp*a(i,j)
    50             CONTINUE
                   jx = jx + incx
    60         CONTINUE
           ELSE
               DO 80 j = 1,n
                   temp = alpha*x(jx)
                   iy = ky
                   DO 70 i = 1,m
                       y(iy) = y(iy) + temp*a(i,j)
                       iy = iy + incy
    70             CONTINUE
                   jx = jx + incx
    80         CONTINUE
           END IF
       ELSE
 *
 *        Form  y := alpha*A**T*x + y.
 *
           jy = ky
           IF (incx.EQ.1) THEN
               DO 100 j = 1,n
                   temp = zero
                   DO 90 i = 1,m
                       temp = temp + a(i,j)*x(i)
    90             CONTINUE
                   y(jy) = y(jy) + alpha*temp
                   jy = jy + incy
   100         CONTINUE
           ELSE
               DO 120 j = 1,n
                   temp = zero
                   ix = kx
                   DO 110 i = 1,m
                       temp = temp + a(i,j)*x(ix)
                       ix = ix + incx
   110             CONTINUE
                   y(jy) = y(jy) + alpha*temp
                   jy = jy + incy
   120         CONTINUE
           END IF
       END IF
 *
       RETURN
 *
 *     End of SGEMV .
 *

 

 

 

参考文献

  1. Basic Linear Algebra Subprograms-WikiPeia
  2. Dongarra J J, Croz J D, Hammarling S, et al. A set of level 3 basic linear algebra subprograms[J]. Acm Transactions on Mathematical Software, 1990, 16(1):1-17.
  3. Why gemm is at the heart of deep learning
  4. sgemm 官方文档
  5. https://www.jianshu.com/p/1dd118f431eb
  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值