Strassen矩阵乘法(分治法续)

矩阵乘法

-1矩阵乘法的定义

矩阵乘法,A*B=C,其中:


那么乘法的定义呢??A矩阵的一行与B矩阵的一列点乘和为C的一个元素。用图形表示是最直观的,其定义就如下图:



-2基本矩阵乘法

那么由上面图中的公式,我们很容易得到基本矩阵相乘的伪代码:

for i = 1 to col // row
	for j = 1 to row // col
		tmp = 0
		for k = 1 to col 
			tmp += A[i*col + k]*B[j + k*row]
		end for
		C[i*col + j] = tmp
	end for
end for

row = col = n 的情况下,那么很容易看出上面的代码由三个for循环来构成的,那么这个代码的时间复杂度为O(n^3)

-3矩阵乘法的改进

怎样对矩阵相乘的算法进行改进呢???一个想当然的想法:分块矩阵相乘!!!


一共有8个(n/2)*(n/2)的矩阵乘法和,4个(n/2)*(n/2)的矩阵加法。再次使用以前的Master Method,

T(n) = 8T(n/2) + T(n^2) 


由此可见,算法的时间复杂度并没有下降,怎么办呢??

下面就到了伟大的Strassen’s Idea了。谁也不知道他是怎么想出来这个算法的,但是呢,一个指导思想是, 要想降低算法的时间复杂度,就要设法降低乘法的次数,这位Strassen做到了, 8次乘法减少到7次!

-4 Strassen矩阵乘法

直接上算法步骤:


将乘法从八次减少到了7次,这个差值1,看起来不起眼,但是这可是


那么我们再来直观地看看这个一次乘法的减少在理论上的性能提升:


由此可见,其在性能上的提升是有多么巨大,在MIT算法导论课件上说,这个算法在n>30时会显示出效果,但是这要跟编程方法有关的,算法好不等于实现性能好

算法的思路就是,当将矩阵分块再分块,当大小为2*2时计算,然后返回,但要注意的是计算每一个P的时候的乘法都是一个小的矩阵的相乘,也要用Strassen方法,所以这是一个递归方法。

-5 Strassen 矩阵乘法实现

算法的C代码如下:

主函数为:

/*Strassen*/
void StrassenMatrixMul(datatype *A,datatype*B,datatype *C, int row,int col)
{
	if (2 == row && 2 == col)//The terminate constraint.
	{
		int p1,p2,p3,p4,p5,p6,p7;
		/*p1~p7*/
		p1 = A[0]*(B[1]-B[3]) ;
		p2 = (A[0] + A[1])*B[3] ;
		p3 = (A[2] + A[3])*B[0] ;
		p4 = A[3]*(B[2] - B[0]) ;
		p5 = (A[0] + A[3])*(B[0] + B[3]) ;
		p6 = (A[1] - A[3])*(B[2] + B[3]) ;
		p7 = (A[0] - A[2])*(B[0] + B[1]) ;
		/*C*/
		C[0] = p5 + p4 - p2 + p6 ;
		C[1] = p1 + p2 ;
		C[2] = p3 + p4 ;
		C[3] = p5 + p1 -p3 - p7 ;
		return ;	
	}
	else
	{
		row = row/2 ;
		col = col/2 ;
		datatype *A1,*A2,*A3,*A4 ;
		/*Init matrix*/
		A1 = InitMatrix(row,col) ;
		A2 = InitMatrix(row,col) ;
		A3 = InitMatrix(row,col) ;
		A4 = InitMatrix(row,col) ;
		/*divide matrix A into [A1, A2
								A3, A4]*/
		MatrixQuarter(A,A1,row*2,col*2,1) ;
		MatrixQuarter(A,A2,row*2,col*2,2) ;
		MatrixQuarter(A,A3,row*2,col*2,3) ;
		MatrixQuarter(A,A4,row*2,col*2,4) ;

		datatype *B1,*B2,*B3,*B4 ;
		B1 = InitMatrix(row,col) ;
		B2 = InitMatrix(row,col) ;
		B3 = InitMatrix(row,col) ;
		B4 = InitMatrix(row,col) ;
		MatrixQuarter(B,B1,row*2,col*2,1) ;
		MatrixQuarter(B,B2,row*2,col*2,2) ;
		MatrixQuarter(B,B3,row*2,col*2,3) ;
		MatrixQuarter(B,B4,row*2,col*2,4) ;

		datatype *C1,*C2,*C3,*C4 ;
		C1 = InitMatrix(row,col) ;
		C2 = InitMatrix(row,col) ;
		C3 = InitMatrix(row,col) ;
		C4 = InitMatrix(row,col) ;

		/*the Ps*/
		datatype *P1,*P2,*P3,*P4,*P5,*P6,*P7;
		P1 = InitMatrix(row,col) ;
		P2 = InitMatrix(row,col) ;
		P3 = InitMatrix(row,col) ;
		P4 = InitMatrix(row,col) ;
		P5 = InitMatrix(row,col) ;
		P6 = InitMatrix(row,col) ;
		P7 = InitMatrix(row,col) ;

		datatype *tmp1, *tmp2;
		tmp1 = InitMatrix(row,col) ;
		tmp2 = InitMatrix(row,col) ;
		
		/*p1*/
		MatrixMinus(B2,B4,tmp1,row,col) ;
		StrassenMatrixMul(A1,tmp1,P1,row,col) ;
		/*p2*/
		MatrixAdd(A1,A2,tmp1,row,col) ;
		StrassenMatrixMul(tmp1,B4,P2,row,col) ;
		/*p3*/
		MatrixAdd(A3,A4,tmp1,row,col) ;
		StrassenMatrixMul(tmp1,B1,P3,row,col) ;
		/*p4*/
		MatrixMinus(B3,B1,tmp1,row,col) ;
		StrassenMatrixMul(A4,tmp1,P4,row,col) ;
		/*p5*/
		MatrixAdd(A1,A4,tmp1,row,col) ;
		MatrixAdd(B1,B4,tmp2,row,col) ;
		StrassenMatrixMul(tmp1,tmp2,P5,row,col) ;
		/*p6*/
		MatrixMinus(A2,A4,tmp1,row,col) ;
		MatrixAdd(B3,B4,tmp2,row,col) ;
		StrassenMatrixMul(tmp1,tmp2,P6,row,col) ;
		/*p7*/
		MatrixMinus(A1,A3,tmp1,row,col) ;
		MatrixAdd(B1,B2,tmp2,row,col) ;
		StrassenMatrixMul(tmp1,tmp2,P7,row,col) ;

		/*C1*/
		MatrixAdd(P5,P4,tmp1,row,col) ;
		MatrixMinus(tmp1,P2,tmp2,row,col) ;
		MatrixAdd(tmp2,P6,C1,row,col) ;
		/*C2*/
		MatrixAdd(P1,P2,C2,row,col) ;
		/*C3*/
		MatrixAdd(P3,P4,C3,row,col) ;
		/*C4*/
		MatrixAdd(P5,P1,tmp1,row,col) ;
		MatrixMinus(tmp1,P3,tmp2,row,col) ;
		MatrixMinus(tmp2,P7,C4,row,col) ;
		/*C1,C2,C3,C4 integrate into C.*/
		MatrixIntegrate(C1,C2,C3,C4,C,row,col) ;

		/*free*/
		free(A1) ;
		free(A2) ;
		free(A3) ;
		free(A4) ;
		free(B1) ;
		free(B2) ;
		free(B3) ;
		free(B4) ;
		free(C1) ;
		free(C2) ;
		free(C3) ;
		free(C4) ;
		free(P1) ;
		free(P2) ;
		free(P3) ;
		free(P4) ;
		free(P5) ;
		free(P6) ;
		free(P7) ;
		free(tmp1) ;
		free(tmp2) ;

		return ;
	}
}

辅助函数1:InitMatrix,初始化矩阵指针:

/*InitMatrix*/
datatype *InitMatrix(int row,int col)
{
	size_t size = sizeof(datatype)*row*col ;
	datatype *p ;
	if (NULL == (p = (datatype *)malloc(size)))
	{
		printf("Allocation storage failed!\n") ;
		return NULL ;
	}
	else
	{
		return p ;
	}
}

辅助函数2:MatrixQuarter,意思为取矩阵的1/4,按第四个参数来决定是哪1/4,看主函数的注释很容易理解:

/*MatrixQuarter*/
//Get 1/4 elements of A into B, indicator is 1,2,3,4.
void MatrixQuarter(datatype *A,datatype*B, int row, int col, int indicator)
{
	int row2 = row/2 ;
	int col2 = col/2 ;
	int i, j ;
	switch(indicator)
	{
	//[r,s
	// t,q], get r
	case 1 :
		{
			for (i = 0; i < row2; i ++)
			{
				for (j = 0; j < col2; j++)
				{
					B[i*col2 + j] = A[i*col + j] ;
				}
			}
			break;
		}
	//[r,s
	// t,q], get s	
	case 2 :
		{
			for (i = 0; i < row2; i ++)
			{
				for (j = 0; j < col2; j++)
				{
					B[i*col2 + j] = A[i*col + j + col2] ;
				}
			}
			break;
		}
	//[r,s
	// t,q], get t
	case 3 :
		{
			for (i = 0; i < row2; i ++)
			{
				for (j = 0; j < col2; j++)
				{
					B[i*col2 + j] = A[(row2+i)*col + j] ;
				}
			}
			break;
		}
	case 4 :
		{
			for (i = 0; i < row2; i ++)
			{
				for (j = 0; j < col2; j++)
				{
					B[i*col2 + j] = A[(row2+i)*col + j + col2] ;
				}
			}
			break;
		}
	default :
		printf("Wrong indicator!\n") ;

	}
}
辅助函数3:MatrixMinus,A-B= C:

/*MatrixMinus*/
void MatrixMinus(datatype *A,datatype*B,datatype *C, int row,int col)
{
	for (int i = 0; i < row*col; i ++)
	{
		C[i] = A[i] - B[i] ;
	}
}
辅助函数4:MatrixAdd,A + B = C:

/*MatrixAdd*/
void MatrixAdd(datatype *A,datatype*B,datatype *C,int row,int col)
{
	for (int i = 0; i < row*col; i ++)
	{
		C[i] = A[i] + B[i] ;
	}	
}
辅助函数5:MatrixIntegrate,与MatrixQuarter对应,将4个子矩阵合成一个大矩阵:

/*MatrixIntegrate*/
//row is the row before integration
void MatrixIntegrate(datatype *A1,datatype *A2,datatype *A3, datatype *A4,datatype *A, int row, int col)
{
	for (int i = 0; i < row; i ++)
	{
		for (int j = 0; j < col; j ++)
		{
			A[i*col*2 + j] = A1[i*col + j] ;
			A[i*col*2 + j + col] = A2[i*col + j] ;
			A[(row+i)*col*2 + j] = A3[i*col + j] ;
			A[(row+i)*col*2 + j + col] = A4[i*col + j] ;
		}
	}
}

上述的方法在Vs2010上是运行通过的,datatype可以声明为任意的,我运行的时候是声明为了int型,需要注意的是:这个实现为了使输入矩阵的大小可变,所有的矩阵存储都是动态申请的,所以上述实现真正运行起来速度甚至比普通的慢很多,这里只为了实现而实现,并没有针对特定应用进行优化,还是那一句,好算法不一定性能好,还得看程序

-6 其他有用资料

下面的第一个文章对于Strassen算法的原理讲解的更加透彻,而第二个文章则总结的很好,条理更加清晰:

http://www.ituring.com.cn/article/17978

http://mindlee.net/2011/11/21/matrix-multiply/








  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值