矩阵乘法
-1矩阵乘法的定义
矩阵乘法,A*B=C,其中:
那么乘法的定义呢??A矩阵的一行与B矩阵的一列点乘和为C的一个元素。用图形表示是最直观的,其定义就如下图:
-2基本矩阵乘法
那么由上面图中的公式,我们很容易得到基本矩阵相乘的伪代码:
for i = 1 to col // row
for j = 1 to row // col
tmp = 0
for k = 1 to col
tmp += A[i*col + k]*B[j + k*row]
end for
C[i*col + j] = tmp
end for
end for
在row = col = n 的情况下,那么很容易看出上面的代码由三个for循环来构成的,那么这个代码的时间复杂度为O(n^3)。
-3矩阵乘法的改进
怎样对矩阵相乘的算法进行改进呢???一个想当然的想法:分块矩阵相乘!!!
一共有8个(n/2)*(n/2)的矩阵乘法和,4个(n/2)*(n/2)的矩阵加法。再次使用以前的Master Method,
T(n) = 8T(n/2) + T(n^2)
由此可见,算法的时间复杂度并没有下降,怎么办呢??
下面就到了伟大的Strassen’s Idea了。谁也不知道他是怎么想出来这个算法的,但是呢,一个指导思想是, 要想降低算法的时间复杂度,就要设法降低乘法的次数,这位Strassen做到了, 将8次乘法减少到7次!-4 Strassen矩阵乘法
直接上算法步骤:
将乘法从八次减少到了7次,这个差值1,看起来不起眼,但是这可是
那么我们再来直观地看看这个一次乘法的减少在理论上的性能提升:
由此可见,其在性能上的提升是有多么巨大,在MIT算法导论课件上说,这个算法在n>30时会显示出效果,但是这要跟编程方法有关的,算法好不等于实现性能好。
算法的思路就是,当将矩阵分块再分块,当大小为2*2时计算,然后返回,但要注意的是计算每一个P的时候的乘法都是一个小的矩阵的相乘,也要用Strassen方法,所以这是一个递归方法。-5 Strassen 矩阵乘法实现
算法的C代码如下:
主函数为:
/*Strassen*/
void StrassenMatrixMul(datatype *A,datatype*B,datatype *C, int row,int col)
{
if (2 == row && 2 == col)//The terminate constraint.
{
int p1,p2,p3,p4,p5,p6,p7;
/*p1~p7*/
p1 = A[0]*(B[1]-B[3]) ;
p2 = (A[0] + A[1])*B[3] ;
p3 = (A[2] + A[3])*B[0] ;
p4 = A[3]*(B[2] - B[0]) ;
p5 = (A[0] + A[3])*(B[0] + B[3]) ;
p6 = (A[1] - A[3])*(B[2] + B[3]) ;
p7 = (A[0] - A[2])*(B[0] + B[1]) ;
/*C*/
C[0] = p5 + p4 - p2 + p6 ;
C[1] = p1 + p2 ;
C[2] = p3 + p4 ;
C[3] = p5 + p1 -p3 - p7 ;
return ;
}
else
{
row = row/2 ;
col = col/2 ;
datatype *A1,*A2,*A3,*A4 ;
/*Init matrix*/
A1 = InitMatrix(row,col) ;
A2 = InitMatrix(row,col) ;
A3 = InitMatrix(row,col) ;
A4 = InitMatrix(row,col) ;
/*divide matrix A into [A1, A2
A3, A4]*/
MatrixQuarter(A,A1,row*2,col*2,1) ;
MatrixQuarter(A,A2,row*2,col*2,2) ;
MatrixQuarter(A,A3,row*2,col*2,3) ;
MatrixQuarter(A,A4,row*2,col*2,4) ;
datatype *B1,*B2,*B3,*B4 ;
B1 = InitMatrix(row,col) ;
B2 = InitMatrix(row,col) ;
B3 = InitMatrix(row,col) ;
B4 = InitMatrix(row,col) ;
MatrixQuarter(B,B1,row*2,col*2,1) ;
MatrixQuarter(B,B2,row*2,col*2,2) ;
MatrixQuarter(B,B3,row*2,col*2,3) ;
MatrixQuarter(B,B4,row*2,col*2,4) ;
datatype *C1,*C2,*C3,*C4 ;
C1 = InitMatrix(row,col) ;
C2 = InitMatrix(row,col) ;
C3 = InitMatrix(row,col) ;
C4 = InitMatrix(row,col) ;
/*the Ps*/
datatype *P1,*P2,*P3,*P4,*P5,*P6,*P7;
P1 = InitMatrix(row,col) ;
P2 = InitMatrix(row,col) ;
P3 = InitMatrix(row,col) ;
P4 = InitMatrix(row,col) ;
P5 = InitMatrix(row,col) ;
P6 = InitMatrix(row,col) ;
P7 = InitMatrix(row,col) ;
datatype *tmp1, *tmp2;
tmp1 = InitMatrix(row,col) ;
tmp2 = InitMatrix(row,col) ;
/*p1*/
MatrixMinus(B2,B4,tmp1,row,col) ;
StrassenMatrixMul(A1,tmp1,P1,row,col) ;
/*p2*/
MatrixAdd(A1,A2,tmp1,row,col) ;
StrassenMatrixMul(tmp1,B4,P2,row,col) ;
/*p3*/
MatrixAdd(A3,A4,tmp1,row,col) ;
StrassenMatrixMul(tmp1,B1,P3,row,col) ;
/*p4*/
MatrixMinus(B3,B1,tmp1,row,col) ;
StrassenMatrixMul(A4,tmp1,P4,row,col) ;
/*p5*/
MatrixAdd(A1,A4,tmp1,row,col) ;
MatrixAdd(B1,B4,tmp2,row,col) ;
StrassenMatrixMul(tmp1,tmp2,P5,row,col) ;
/*p6*/
MatrixMinus(A2,A4,tmp1,row,col) ;
MatrixAdd(B3,B4,tmp2,row,col) ;
StrassenMatrixMul(tmp1,tmp2,P6,row,col) ;
/*p7*/
MatrixMinus(A1,A3,tmp1,row,col) ;
MatrixAdd(B1,B2,tmp2,row,col) ;
StrassenMatrixMul(tmp1,tmp2,P7,row,col) ;
/*C1*/
MatrixAdd(P5,P4,tmp1,row,col) ;
MatrixMinus(tmp1,P2,tmp2,row,col) ;
MatrixAdd(tmp2,P6,C1,row,col) ;
/*C2*/
MatrixAdd(P1,P2,C2,row,col) ;
/*C3*/
MatrixAdd(P3,P4,C3,row,col) ;
/*C4*/
MatrixAdd(P5,P1,tmp1,row,col) ;
MatrixMinus(tmp1,P3,tmp2,row,col) ;
MatrixMinus(tmp2,P7,C4,row,col) ;
/*C1,C2,C3,C4 integrate into C.*/
MatrixIntegrate(C1,C2,C3,C4,C,row,col) ;
/*free*/
free(A1) ;
free(A2) ;
free(A3) ;
free(A4) ;
free(B1) ;
free(B2) ;
free(B3) ;
free(B4) ;
free(C1) ;
free(C2) ;
free(C3) ;
free(C4) ;
free(P1) ;
free(P2) ;
free(P3) ;
free(P4) ;
free(P5) ;
free(P6) ;
free(P7) ;
free(tmp1) ;
free(tmp2) ;
return ;
}
}
辅助函数1:InitMatrix,初始化矩阵指针:
/*InitMatrix*/
datatype *InitMatrix(int row,int col)
{
size_t size = sizeof(datatype)*row*col ;
datatype *p ;
if (NULL == (p = (datatype *)malloc(size)))
{
printf("Allocation storage failed!\n") ;
return NULL ;
}
else
{
return p ;
}
}
辅助函数2:MatrixQuarter,意思为取矩阵的1/4,按第四个参数来决定是哪1/4,看主函数的注释很容易理解:
/*MatrixQuarter*/
//Get 1/4 elements of A into B, indicator is 1,2,3,4.
void MatrixQuarter(datatype *A,datatype*B, int row, int col, int indicator)
{
int row2 = row/2 ;
int col2 = col/2 ;
int i, j ;
switch(indicator)
{
//[r,s
// t,q], get r
case 1 :
{
for (i = 0; i < row2; i ++)
{
for (j = 0; j < col2; j++)
{
B[i*col2 + j] = A[i*col + j] ;
}
}
break;
}
//[r,s
// t,q], get s
case 2 :
{
for (i = 0; i < row2; i ++)
{
for (j = 0; j < col2; j++)
{
B[i*col2 + j] = A[i*col + j + col2] ;
}
}
break;
}
//[r,s
// t,q], get t
case 3 :
{
for (i = 0; i < row2; i ++)
{
for (j = 0; j < col2; j++)
{
B[i*col2 + j] = A[(row2+i)*col + j] ;
}
}
break;
}
case 4 :
{
for (i = 0; i < row2; i ++)
{
for (j = 0; j < col2; j++)
{
B[i*col2 + j] = A[(row2+i)*col + j + col2] ;
}
}
break;
}
default :
printf("Wrong indicator!\n") ;
}
}
辅助函数3:MatrixMinus,A-B= C:
/*MatrixMinus*/
void MatrixMinus(datatype *A,datatype*B,datatype *C, int row,int col)
{
for (int i = 0; i < row*col; i ++)
{
C[i] = A[i] - B[i] ;
}
}
辅助函数4:MatrixAdd,A + B = C:
/*MatrixAdd*/
void MatrixAdd(datatype *A,datatype*B,datatype *C,int row,int col)
{
for (int i = 0; i < row*col; i ++)
{
C[i] = A[i] + B[i] ;
}
}
辅助函数5:MatrixIntegrate,与MatrixQuarter对应,将4个子矩阵合成一个大矩阵:
/*MatrixIntegrate*/
//row is the row before integration
void MatrixIntegrate(datatype *A1,datatype *A2,datatype *A3, datatype *A4,datatype *A, int row, int col)
{
for (int i = 0; i < row; i ++)
{
for (int j = 0; j < col; j ++)
{
A[i*col*2 + j] = A1[i*col + j] ;
A[i*col*2 + j + col] = A2[i*col + j] ;
A[(row+i)*col*2 + j] = A3[i*col + j] ;
A[(row+i)*col*2 + j + col] = A4[i*col + j] ;
}
}
}
上述的方法在Vs2010上是运行通过的,datatype可以声明为任意的,我运行的时候是声明为了int型,需要注意的是:这个实现为了使输入矩阵的大小可变,所有的矩阵存储都是动态申请的,所以上述实现真正运行起来速度甚至比普通的慢很多,这里只为了实现而实现,并没有针对特定应用进行优化,还是那一句,好算法不一定性能好,还得看程序。
-6 其他有用资料
下面的第一个文章对于Strassen算法的原理讲解的更加透彻,而第二个文章则总结的很好,条理更加清晰:
http://www.ituring.com.cn/article/17978