darknet源码分析(三)gemm实现

上一节中我们分析了darknet卷积层的前向操作过程,darknet对卷积计算的处理实际上是:

先使用im2col将input_channel*(height*width)的输入特征图(实际存储是按照行存储的,即是1*(input_channel*height*width)的一维数组)转化成(input_channel*kernel_size*kernel_size)*(out_height*out_width)的特征矩阵,这里同样是按行存储的。

之后通过gemm函数实现通用的矩阵乘法实现卷积计算,即让卷积核矩阵*im2col后的输入特征矩阵。其中卷积核的大小为(kenel_channel)*(inut_channel*kernel_size*kernel_size)

最后得到kernel_channel*(out_height*out_width)即卷积输出的最终结果。

这一节,我们来看非常重要的gemm。

首先来看gemm.h,看其中的gemm_cpu,实现的是C=ALPHA*A*B + BETA*C操作,这里的BETA*C表示加上偏置项。通过其传入的参数,我们会发现gemm_cpu会对通过TA和TB变量来判断是否对卷积核矩阵A和输入的经过im2col转换过的特征矩阵B进行转置操作,因此真正实现的时候会根据不同的转置方式来采取不同的gemm方法。由于我们的A与B的存储方式都是一维数组,因此输入参数还要包含A与B的行与列。gemm_cpu的输入参数解释如下

/*

**  功能:矩阵计算,完成C = ALPHA * A * B + BETA * C,

**       输出的C也是按行存储(所有行并成一行)

**  输入: A,B,C   输入矩阵(一维数组格式,按行存储,所有行并成一行)

**        ALPHA   系数

**        BETA    系数

**        TA,TB   是否需要对A,B做转置操作,是为1,否为0

**        M       A,C的行数

**        N       B,C的列数

**        K       A的列数,B的行数

**        lda     A的列数(不做转置)或者行数(做转置)

**        ldb     B的列数(不做转置)或者行数(做转置)

**        ldc     C的列数


*/

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)

现在来看gemm到底是怎么实现的吧,来到gemm.c中,查看gemm_cpu()函数

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
    int i, j;
	/*首先完成BETA * C的操作*/
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
	/*根据指定的TA和TB来选择不同的矩阵乘法方法,如gemm_nn就代表A与B都不进行转置操作的情况,gemm_tn代表对A进行转置*/
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}

这里做一下解释:

当TA=0,TB=0时,我们进行的是C = ALPHA * A * B + BETA * C操作

当TA=1,TB=0时,进行的是C = ALPHA * A' * B + BETA * C操作

当TA=0,TB=1时,进行的是C = ALPHA * A * B' + BETA * C操作

当TA=1,TB=1时,进行的时C = ALPHA * A' * B' + BETA * C操作

也就是说我们的矩阵乘法一定要符合矩阵乘法准则

例如A = [1, 2, 3, 2, 2, 1], B = [2, 0, 1, 1, 2, 1],C=[0,0,0,0](因为是按行存储的,所以都是一维数组,这个输入是打死不变的。进行矩阵乘法时要将其想象成多维进行处理)我们最后的输出C假设是2*2的矩阵。这样进行矩阵乘法的A与B分别是2*3和3*2,这样如果使用的是gemm_nn,A矩阵的实际大小就是2*3,同理B矩阵的大小是3*2。如果使用的是gemm_tn,也就是说A矩阵经过转置变成了2*3的矩阵,这样也就是说输入的A矩阵为3*2的即[1,2;3,2;2,1],而B矩阵没有经过转置,因此B矩阵为[2,0;1,1;2,1], A'与B相乘最后算出的C为[9,5;8,3]。而当使用的是gemm_tt时,也就是A与B都进行了转置,这样A应该就是[1,2;3,2;2,1],B矩阵应为[2,0,1;1,2,1],这样C应为[4,9;5,7]。下面来分析一下gemm_nn的具体实现方法。

/*
** M:C的行数,因为这里A没有做转置换操作,因此这里A的行数是M
** N:C的列数,因为这里B也没有做转置操作,因此这里B的列数是N
** K:这里都没有转置,因此K代表A的列数,B的行数
** lda: 不转置时该变量是A的列数,因此A的列数是lda
** ldb: 不转置时该变量时B的行数,因此B的行数是ldb
** ldc: C的列数
*/
void gemm_nn(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    // 遍历A的每一行
    for(i = 0; i < M; ++i){
		// 遍历A的每一列
        for(k = 0; k < K; ++k){
			// 首先将A_PART*A的操作做完
            register float A_PART = ALPHA*A[i*lda+k];
			// 使用A_PART的第i行的所有数与B第k列的所有数做乘加操作
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}

与之形成对比,下面来看gemm_tt的操作过程

/*
** M:C的行数,因为这里A做转置换操作,因此这里A的列数是M
** N:C的列数,因为这里B做转置操作,因此这里B的行数是N
** K:这里都转置,因此K代表A的行数,B的列数
** lda: 不转置时该变量是A的列数,因此A的行数是lda
** ldb: 不转置时该变量时B的行数,因此B的列数是ldb
** ldc: C的列数
*/
void gemm_tt(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
	// 这里遍历的是A的列数,C的行数
    for(i = 0; i < M; ++i){
		// 遍历的是C的列数,B的行数
        for(j = 0; j < N; ++j){
            register float sum = 0;
			// 遍历A的行和B的列
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
            }
            C[i*ldc+j] += sum;
        }
    }
}

可以看到这几个gemm之间的区别就在于三个for循环的位置排序和最后对C中第i行第j列的数计算的位置放的位置不同。同时注意到在三个for循环之前都有一句#pragma omp parallel for 。这个是OpenMP中的一个指令,表示接下来的for循环将被多线程执行,这个语句要求几个for循环之间不能够有依赖。OpenMP 是 Open MultiProcessing 的缩写。OpenMP 并不是一个简单的函数库,而是一个诸多编译器支持的框架,或者说是协议吧,总之,不需要任何配置,你就可以在 Visual Studio 或者 gcc 中使用它了。OpenMP的设计们希望提供一种简单的方式让程序员不需要懂得创建和销毁线程就能写出多线程化程序。为此他们设计了一些pragma,指令和函数来让编译器能够在合适的地方插入线程大多数的循环只需要在for之前插入一个pragma就可以实现并行化。

我们这里的三个循环操作之间没依赖,因此完全可以使用openmp的pragma omp parallel for 来进行并行化使得操作更加快。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值