darknet对卷积计算的处理实际上是:
先使用im2col将input_channel*(height*width)的输入特征图(实际存储是按照行存储的,即是1*(input_channel*height*width)的一维数组)转化成(input_channel*kernel_size*kernel_size)*(out_height*out_width)的特征矩阵,这里同样是按行存储的。
之后通过gemm函数实现通用的矩阵乘法实现卷积计算,即让卷积核矩阵*im2col后的输入特征矩阵。其中卷积核的大小为(kenel_channel)*(inut_channel*kernel_size*kernel_size)
最后得到kernel_channel*(out_height*out_width)即卷积输出的最终结果。
首先来看gemm.h,看其中的gemm_cpu,实现的是C=ALPHA*A*B + BETA*C操作,这里的BETA*C表示加上偏置项。通过其传入的参数,我们会发现gemm_cpu会对通过TA和TB变量来判断是否对卷积核矩阵A和输入的经过im2col转换过的特征矩阵B进行转置操作,因此真正实现的时候会根据不同的转置方式来采取不同的gemm方法。由于我们的A与B的存储方式都是一维数组,因此输入参数还要包含A与B的行与列。gemm_cpu的输入参数解释如下
/*
** 功能:矩阵计算,完成C = ALPHA * A * B + BETA * C,
** 输出的C也是按行存储(所有行并成一行)
** 输入: A,B,C 输入矩阵(一维数组格式,按行存储,所有行并成一行)
** ALPHA 系数
** BETA 系数
** TA,TB 是否需要对A,B做转置操作,是为1,否为0
** M A,C的行数
** N B,C的列数
** K A的列数,B的行数
** lda A的列数(不做转置)或者行数(做转置)
** ldb B的列数(不做转置)或者行数(做转置)
** ldc C的列数
*/
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
//printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
int i, j;
/*首先完成BETA * C的操作*/
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
C[i*ldc + j] *= BETA;
}
}
/*根据指定的TA和TB来选择不同的矩阵乘法方法,如gemm_nn就代表A与B都不进行转置操作的情况,gemm_tn代表对A进行转置*/
if(!TA && !TB)
gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else if(TA && !TB)
gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else if(!TA && TB)
gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
else
gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}
/*
** M:C的行数,因为这里A没有做转置换操作,因此这里A的行数是M
** N:C的列数,因为这里B也没有做转置操作,因此这里B的列数是N
** K:这里都没有转置,因此K代表A的列数,B的行数
** lda: 不转置时该变量是A的列数,因此A的列数是lda
** ldb: 不转置时该变量时B的行数,因此B的行数是ldb
** ldc: C的列数
*/
void gemm_nn(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
// 遍历A的每一行
for(i = 0; i < M; ++i){
// 遍历A的每一列
for(k = 0; k < K; ++k){
// 首先将A_PART*A的操作做完
register float A_PART = ALPHA*A[i*lda+k];
// 使用A_PART的第i行的所有数与B第k列的所有数做乘加操作
for(j = 0; j < N; ++j){
C[i*ldc+j] += A_PART*B[k*ldb+j];
}
}
}
}
/*
** M:C的行数,因为这里A做转置换操作,因此这里A的列数是M
** N:C的列数,因为这里B做转置操作,因此这里B的行数是N
** K:这里都转置,因此K代表A的行数,B的列数
** lda: 不转置时该变量是A的列数,因此A的行数是lda
** ldb: 不转置时该变量时B的行数,因此B的列数是ldb
** ldc: C的列数
*/
void gemm_tt(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
// 这里遍历的是A的列数,C的行数
for(i = 0; i < M; ++i){
// 遍历的是C的列数,B的行数
for(j = 0; j < N; ++j){
register float sum = 0;
// 遍历A的行和B的列
for(k = 0; k < K; ++k){
sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
}
C[i*ldc+j] += sum;
}
}
}
可以看到这几个gemm之间的区别就在于三个for循环的位置排序和最后对C中第i行第j列的数计算的位置放的位置不同。同时注意到在三个for循环之前都有一句#pragma omp parallel for 。这个是OpenMP中的一个指令,表示接下来的for循环将被多线程执行,这个语句要求几个for循环之间不能够有依赖。OpenMP 是 Open MultiProcessing 的缩写。OpenMP 并不是一个简单的函数库,而是一个诸多编译器支持的框架,或者说是协议吧,总之,不需要任何配置,你就可以在 Visual Studio 或者 gcc 中使用它了。OpenMP的设计们希望提供一种简单的方式让程序员不需要懂得创建和销毁线程就能写出多线程化程序。为此他们设计了一些pragma,指令和函数来让编译器能够在合适的地方插入线程大多数的循环只需要在for之前插入一个pragma就可以实现并行化。
我们这里的三个循环操作之间没依赖,因此完全可以使用openmp的pragma omp parallel for 来进行并行化使得操作更加快。