分治算法求解n介方阵相乘(Strassen)

这两天在看《算法导论》第4章,看到有关Strassen算法和递归式求解的内容,着实让小弟头疼了一把。经典就是经典,几十年前就出现的算法,到现在仍旧让人连连称赞,看来学生还得继续奋斗啊!!!

看了一天多,还是没把全部内容参透,就简单把Strassen算法跟大家分享一下,希望大家多多指点。

两个n介方阵相乘,最简单的算法就是使用三层for循环求解,时间复杂度是O(n3),在此不再详述。

第二种算法是使用分治算法,把方阵(假设n=2k)划分为4个n/2*n/2的的子方阵,如下所示:

A,B,C是三个n介方阵,要计算C=A*B,可把A, B, C分别分解为4个n/2*n/2的方阵记为

A:A11  A12    B:B11  B12    C:C11  C12

       A21 A22         B21  B22          C21  C22

因此,就有

C11 = A11*B11+A12*B21

C12 = A11*B12+A12*B22

C21 = A21*B11+A22*B21

C22 = A21*B12+A22*B22  

每个公式包括两个乘法和一个加法,由此可以设计一个递归算法:

//矩阵加法:C=A+B,从C的(xC,yC)位置开始保存
int * squareMatrixAdd(int *A, int *B, int *C, int xC, int yC, int n) 
{
	for (int x1=0, x2=xC; x1<n; x1++, x2++)
	{
		for (int y1=0, y2=yC; y1<n; y1++, y2++)
		{
			C[x2*n*2+y2] = A[x1*n+y1]+B[x1*n+y1];
		}
	}
	return C;
}
//分治方法:C=A*B, 分别从A的(xA, xB),B的(xB, yB)开始计算,从C的(xC,yC)开始保存,这样可借用下标就可分解矩阵
int * squareMatrixMultiplyRecursive(int *A, int xA, int yA, int *B, int xB, int yB, int *C, int xC, int yC, int n)
{
	static int saveN = n;
	if (1 == n)
	{
		C[xC+yC] = A[xA*saveN+yA] * B[xB*saveN+yB];
	}
	else
	{
		int n1 = n/2; //分解
		int *tempC1 = new int[n1*n1];
		int *tempC2 = new int[n1*n1];
		squareMatrixMultiplyRecursive(A, xA, yA, B, xB, yB, tempC1, 0, 0, n1);//A11*B11
		squareMatrixMultiplyRecursive(A, xA, yA+n1, B, xB+n1, yB, tempC2, 0, 0, n1);//A12*B21
		squareMatrixAdd(tempC1, tempC2, C, xC, yC, n1);//C11=A11*B11+A12*B21

		squareMatrixMultiplyRecursive(A, xA, yA, B, xB, yB+n1, tempC1, 0, 0, n1);//A11*B12
		squareMatrixMultiplyRecursive(A, xA, yA+n1, B, xB+n1, yB+n1, tempC2, 0, 0, n1);//A12*B22
		squareMatrixAdd(tempC1, tempC2, C, xC, yC+n1, n1);//C12=A11*B12+A12*B22

		squareMatrixMultiplyRecursive(A, xA+n1, yA, B, xB, yB, tempC1, 0, 0, n1);//A21*B11
		squareMatrixMultiplyRecursive(A, xA+n1, yA+n1, B, xB+n1, yB, tempC2, 0, 0, n1);//A22*B21
		squareMatrixAdd(tempC1, tempC2, C, xC+n1, yC, n1);//C21=A21*B11+A22*B21

		squareMatrixMultiplyRecursive(A, xA+n1, yA, B, xB, yB+n1, tempC1, 0, 0, n1);//A21*B12
		squareMatrixMultiplyRecursive(A, xA+n1, yA+n1, B, xB+n1, yB+n1, tempC2, 0, 0, n1);//A22*B22
		squareMatrixAdd(tempC1, tempC2, C, xC+n1, yC+n1, n1);//C22=A21*B12+A22*B22

		delete []tempC1;
		delete []tempC2;
	}
	return C;
}


由算法可以看出,把原问题可分成8子问题,每个子问题是原问题的1/2,时间复杂度但还是O(n3), 速度并没有增加,而且空间复杂度大大增加,因此这只是一种思想,使用价值不大。下面就说一下Strassen算法。

Strassen算法的核心是将第二种算法的8次乘法变为7次,当矩阵规模大于一定阈值时,将远远快于以上两种算法。Strassen算法分为四步,详细步骤在此不再详细叙述,大家可下去自己学习。

Strassen算法的时间复杂度是O(nlg7),因此比上述两种算法快了。其实Strassen算法并不是计算矩阵相乘的最好选择,它计算过程有着巨大的空间耗费和精度损失,计算稀疏矩阵不如普通方法好。但是对于稠密矩阵,矩阵规模大于一定阈值时,Strassen算法还是有着很令人满意的表现,这个阈值一般与具体的系统有关。

今天就写到这,感觉写的没什么水平,就算是自己的一点学习笔记,练习了一下编程。还请多多评论!

下面给出我自己的Strassen算法实现。

//计算从A的(xA,yA),B的(xB, yB)位置开始的n介矩阵的乘积
int ** Strassen(int **A, int xA, int yA, int **B, int xB, int yB, int n)
{
	int **C = createMat(n);
	if (n==1)
	{
		C[0][0] = A[xA][yA] * B[xB][yB];
	}
	else
	{
		n /= 2; //分解
		int **S1 = createS(B, xB, yB+n, B, xB+n, yB+n, n, 1);
		int **S2 = createS(A, xA, yA, A, xA, yA+n, n, 0);
		int **S3 = createS(A, xA+n, yA, A, xA+n, yA+n, n, 0); 
		int **S4 = createS(B, xB+n, yB, B, xB, yB, n, 1); 
		int **S5 = createS(A, xA, yA, A, xA+n, yA+n, n, 0);
		int **S6 = createS(B, xB, yB, B, xB+n, yB+n, n, 0);
		int **S7 = createS(A, xA, yA+n, A, xA+n, yA+n, n, 1);
		int **S8 = createS(B, xB+n, yB, B, xB+n, yB+n, n, 0);
		int **S9 = createS(A, xA, yA, A, xA+n, yA, n, 1);
		int **S10 = createS(B, xB, yB, B, xB, yB+n, n, 0);

		int **P1 = Strassen(A, xA, yA, S1, 0, 0, n);
		int **P2 = Strassen(S2, 0, 0, B, xB+n, yB+n, n);
		int **P3 = Strassen(S3, 0, 0, B, xB, yB, n);
		int **P4 = Strassen(A, xA+n, yA+n, S4, 0, 0, n);
		int **P5 = Strassen(S5, 0, 0, S6, 0, 0, n);
		int **P6 = Strassen(S7, 0, 0, S8, 0, 0, n);
		int **P7 = Strassen(S9, 0, 0, S10, 0, 0, n);

		//C11
		add(P5, 0, 0, P4, 0, 0, C, 0, 0, n);
		sub(C, 0, 0, P2, 0, 0, C, 0, 0, n);
		add(C, 0, 0, P6, 0, 0, C, 0, 0, n);
		//C12
		add(P1, 0, 0, P2, 0, 0, C, 0, n, n);
		//C21
		add(P3, 0, 0, P4, 0, 0, C, n, 0, n);
		//C22
		add(P5, 0, 0, P1, 0, 0, C, n, n, n);
		sub(C, n, n, P3, 0, 0, C, n, n, n);
		sub(C, n, n, P7, 0, 0, C, n, n, n);

		deleteMat(S1, n);
		deleteMat(S2, n);
		deleteMat(S3, n);
		deleteMat(S4, n);
		deleteMat(S5, n);
		deleteMat(S6, n);
		deleteMat(S7, n);
		deleteMat(S8, n);
		deleteMat(S9, n);
		deleteMat(S10, n);
		deleteMat(P1, n);
		deleteMat(P2, n);
		deleteMat(P3, n);
		deleteMat(P4, n);
		deleteMat(P5, n);
		deleteMat(P6, n);
		deleteMat(P7, n);
	}
	return C;
}
//矩阵加法
int ** add(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n)
{
	for (int i=0; i<n; i++)
	{
		for (int j=0; j<n; j++)
		{
			C[xC+i][yC+j] = A[xA+i][yA+j] + B[xB+i][yB+j];
		}
	}
	return C;
}
//矩阵减法
int **sub(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n)
{
	for (int i=0; i<n; i++)
	{
		for (int j=0; j<n; j++)
		{
			C[xC+i][yC+j] = A[xA+i][yA+j] - B[xB+i][yB+j];
		}
	}
	return C;
}
//创建矩阵
int **createMat(int n)
{
	int **M = new int *[n];
	if (!M)
	{
		exit(0);
	}
	for (int i=0; i<n; i++)
	{
		M[i] = new int[n];
		if (!M[i])
		{
			for (int j=0; j<i; j++)
			{
				delete[]M[j];
			}
			delete []M;
			exit(1);
		}
		for (int j=0; j<n; j++)
		{
			M[i][j] = 0;
		}
	}
	return M;
}
//删除矩阵
void deleteMat(int **Mat, int n)
{
	for (int i=0; i<n; i++)
	{
		delete []Mat[i];
	}
	delete []Mat;
	Mat = NULL;
}
//创建S矩阵
int **createS(int **A, int xA, int yA, int **B, int xB, int yB, int n, int type)
{
	int **S = createMat(n);
	if (0==type)
	{
		add(A, xA, yA, B, xB, yB, S, 0, 0, n);
	}
	else
	{
		sub(A, xA, yA, B, xB, yB, S, 0, 0, n);
	}
	return S;
}


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值