这两天在看《算法导论》第4章,看到有关Strassen算法和递归式求解的内容,着实让小弟头疼了一把。经典就是经典,几十年前就出现的算法,到现在仍旧让人连连称赞,看来学生还得继续奋斗啊!!!
看了一天多,还是没把全部内容参透,就简单把Strassen算法跟大家分享一下,希望大家多多指点。
两个n阶方阵相乘,最简单的算法就是使用三层for循环求解,时间复杂度是 $O(n^3)$,在此不再详述。
第二种算法是使用分治算法,把方阵(假设 $n=2^k$)划分为4个 $n/2 \times n/2$ 的子方阵,如下所示:
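A,B,C是三个n阶方阵,要计算C=A*B,可把A, B, C分别分解为4个 $n/2 \times n/2$ 的子方阵,记为

$$
A=\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix},\quad
B=\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix},\quad
C=\begin{pmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{pmatrix}
$$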
因此,就有
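$$
\begin{aligned}
C_{11} &= A_{11}B_{11} + A_{12}B_{21} \\
C_{12} &= A_{11}B_{12} + A_{12}B_{22} \\
C_{21} &= A_{21}B_{11} + A_{22}B_{21} \\
C_{22} &= A_{21}B_{12} + A_{22}B_{22}
\end{aligned}
$$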
每个公式包括两个乘法和一个加法,由此可以设计一个递归算法:
// Matrix addition used by the divide-and-conquer multiplier:
// computes A + B element-wise and stores the n x n result into C, with its
// top-left corner at row xC, column yC.
//   A, B : n x n row-major matrices, row stride n.
//   C    : destination, indexed with a row stride of 2*n -- i.e. C is taken to
//          be the 2n x 2n parent matrix whose quadrant is being filled. This is
//          not a general-purpose add for arbitrary C widths.
// Returns C.
int * squareMatrixAdd(int *A, int *B, int *C, int xC, int yC, int n)
{
    const int widthC = 2 * n;  // row stride of the destination matrix
    for (int row = 0; row < n; ++row)
    {
        int *dst = C + (xC + row) * widthC + yC;
        const int *lhs = A + row * n;
        const int *rhs = B + row * n;
        for (int col = 0; col < n; ++col)
        {
            dst[col] = lhs[col] + rhs[col];
        }
    }
    return C;
}
// Recursive worker for the divide-and-conquer (8-multiplication) algorithm.
// Computes the product of the n x n block of A starting at (xA, yA) and the
// n x n block of B starting at (xB, yB); A and B are row-major buffers read
// with row stride `stride`. The n x n result is written into C, whose row
// stride is n, starting at (xC, yC).
// Even n is split into four (n/2) x (n/2) quadrants and computed with 8
// recursive products; odd n (including n == 1) uses the plain triple loop, so
// any n >= 1 works, not only powers of two.
static void squareMatrixMultiplyRecursiveStrided(int *A, int xA, int yA,
                                                 int *B, int xB, int yB,
                                                 int *C, int xC, int yC,
                                                 int n, int stride)
{
    if (n <= 0)
    {
        return;  // nothing to compute; also stops an endless n/2 == 0 recursion
    }

    if (n % 2 != 0)
    {
        // Cannot split into equal quadrants: multiply directly.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                int sum = 0;
                for (int k = 0; k < n; k++)
                {
                    sum += A[(xA + i) * stride + (yA + k)] * B[(xB + k) * stride + (yB + j)];
                }
                C[(xC + i) * n + (yC + j)] = sum;
            }
        }
        return;
    }

    int n1 = n / 2;  // quadrant size

    // One buffer holds both n1 x n1 partial products (each with row stride n1).
    // A single allocation means a failed allocation cannot leak a sibling buffer.
    int *buf = new int[2 * n1 * n1];
    int *tempC1 = buf;
    int *tempC2 = buf + n1 * n1;
    try
    {
        squareMatrixMultiplyRecursiveStrided(A, xA, yA, B, xB, yB, tempC1, 0, 0, n1, stride);                //A11*B11
        squareMatrixMultiplyRecursiveStrided(A, xA, yA + n1, B, xB + n1, yB, tempC2, 0, 0, n1, stride);      //A12*B21
        squareMatrixAdd(tempC1, tempC2, C, xC, yC, n1);                                                      //C11
        squareMatrixMultiplyRecursiveStrided(A, xA, yA, B, xB, yB + n1, tempC1, 0, 0, n1, stride);           //A11*B12
        squareMatrixMultiplyRecursiveStrided(A, xA, yA + n1, B, xB + n1, yB + n1, tempC2, 0, 0, n1, stride); //A12*B22
        squareMatrixAdd(tempC1, tempC2, C, xC, yC + n1, n1);                                                 //C12
        squareMatrixMultiplyRecursiveStrided(A, xA + n1, yA, B, xB, yB, tempC1, 0, 0, n1, stride);           //A21*B11
        squareMatrixMultiplyRecursiveStrided(A, xA + n1, yA + n1, B, xB + n1, yB, tempC2, 0, 0, n1, stride); //A22*B21
        squareMatrixAdd(tempC1, tempC2, C, xC + n1, yC, n1);                                                 //C21
        squareMatrixMultiplyRecursiveStrided(A, xA + n1, yA, B, xB, yB + n1, tempC1, 0, 0, n1, stride);      //A21*B12
        squareMatrixMultiplyRecursiveStrided(A, xA + n1, yA + n1, B, xB + n1, yB + n1, tempC2, 0, 0, n1, stride); //A22*B22
        squareMatrixAdd(tempC1, tempC2, C, xC + n1, yC + n1, n1);                                            //C22
    }
    catch (...)
    {
        delete[] buf;  // don't leak if a deeper allocation throws
        throw;
    }
    delete[] buf;
}

// Divide-and-conquer matrix multiply: C = A * B.
// A and B are n x n row-major matrices (row stride n), read starting at
// (xA, yA) and (xB, yB); the result is stored into C (row stride n) starting
// at (xC, yC). Since C's row stride is n, top-level callers should pass
// (0, 0) for all offsets. Returns C.
// The row stride is passed explicitly on every call (instead of being cached
// in a function-local static that is only initialised once), so consecutive
// calls with different n are correct.
int * squareMatrixMultiplyRecursive(int *A, int xA, int yA, int *B, int xB, int yB, int *C, int xC, int yC, int n)
{
    squareMatrixMultiplyRecursiveStrided(A, xA, yA, B, xB, yB, C, xC, yC, n, n);
    return C;
}
由算法可以看出,原问题被分成8个子问题,每个子问题的规模是原问题的1/2,再加上 $\Theta(n^2)$ 的矩阵加法,得到递归式 $T(n)=8T(n/2)+\Theta(n^2)$,解得时间复杂度仍然是 $\Theta(n^3)$,速度并没有提高;而且每层递归都要申请临时矩阵,空间开销大大增加,因此这只是一种思想,实用价值不大。下面就说一下Strassen算法。
Strassen算法的核心是把第二种算法中的8次递归乘法减少为7次(代价是多做常数次矩阵加减法),当矩阵规模大于一定阈值时,将远远快于以上两种算法。Strassen算法分为四步,详细步骤在此不再叙述,大家可以自行学习。
Strassen算法的递归式为 $T(n)=7T(n/2)+\Theta(n^2)$,时间复杂度是 $O(n^{\lg 7})\approx O(n^{2.81})$,因此在渐近意义上比上述两种算法快。其实Strassen算法并不是计算矩阵相乘的最好选择:它的计算过程有着巨大的空间耗费,对浮点数还会带来精度损失,计算稀疏矩阵也不如普通方法好。但是对于稠密矩阵,当矩阵规模大于一定阈值时,Strassen算法还是有着很令人满意的表现,这个阈值一般与具体的系统有关。
今天就写到这,感觉写的没什么水平,就算是自己的一点学习笔记,练习了一下编程。还请多多评论!
下面给出我自己的Strassen算法实现。
// Helpers defined later in this file; declared here because Strassen uses
// them before their definitions.
int **createMat(int n);
void deleteMat(int **Mat, int n);
int **add(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n);
int **sub(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n);
int **createS(int **A, int xA, int yA, int **B, int xB, int yB, int n, int type);

// Strassen multiplication of the n x n block of A starting at (xA, yA) by the
// n x n block of B starting at (xB, yB).
// Returns a newly allocated n x n matrix (from createMat); the caller owns it
// and must release it with deleteMat(result, n).
// Even n: split into (n/2) x (n/2) quadrants and combine 7 recursive products.
// Odd n (including n == 1): plain triple loop, so any n >= 1 is handled
// correctly, not only powers of two. n == 0 yields an empty matrix.
int ** Strassen(int **A, int xA, int yA, int **B, int xB, int yB, int n)
{
    int **C = createMat(n);
    if (n == 0)
    {
        return C;  // avoids endless recursion on n/2 == 0
    }

    if (n % 2 != 0)
    {
        // Quadrants would not be equal-sized: multiply directly.
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                int sum = 0;
                for (int k = 0; k < n; k++)
                {
                    sum += A[xA + i][yA + k] * B[xB + k][yB + j];
                }
                C[i][j] = sum;
            }
        }
        return C;
    }

    const int h = n / 2;  // quadrant size

    // Ten quadrant sums/differences (createS type: 0 = add, 1 = subtract).
    int **S1 = createS(B, xB, yB + h, B, xB + h, yB + h, h, 1);   // B12 - B22
    int **S2 = createS(A, xA, yA, A, xA, yA + h, h, 0);           // A11 + A12
    int **S3 = createS(A, xA + h, yA, A, xA + h, yA + h, h, 0);   // A21 + A22
    int **S4 = createS(B, xB + h, yB, B, xB, yB, h, 1);           // B21 - B11
    int **S5 = createS(A, xA, yA, A, xA + h, yA + h, h, 0);       // A11 + A22
    int **S6 = createS(B, xB, yB, B, xB + h, yB + h, h, 0);       // B11 + B22
    int **S7 = createS(A, xA, yA + h, A, xA + h, yA + h, h, 1);   // A12 - A22
    int **S8 = createS(B, xB + h, yB, B, xB + h, yB + h, h, 0);   // B21 + B22
    int **S9 = createS(A, xA, yA, A, xA + h, yA, h, 1);           // A11 - A21
    int **S10 = createS(B, xB, yB, B, xB, yB + h, h, 0);          // B11 + B12

    // Seven recursive products instead of eight.
    int **P1 = Strassen(A, xA, yA, S1, 0, 0, h);           // A11 * S1
    int **P2 = Strassen(S2, 0, 0, B, xB + h, yB + h, h);   // S2 * B22
    int **P3 = Strassen(S3, 0, 0, B, xB, yB, h);           // S3 * B11
    int **P4 = Strassen(A, xA + h, yA + h, S4, 0, 0, h);   // A22 * S4
    int **P5 = Strassen(S5, 0, 0, S6, 0, 0, h);            // S5 * S6
    int **P6 = Strassen(S7, 0, 0, S8, 0, 0, h);            // S7 * S8
    int **P7 = Strassen(S9, 0, 0, S10, 0, 0, h);           // S9 * S10

    // C11 = P5 + P4 - P2 + P6
    add(P5, 0, 0, P4, 0, 0, C, 0, 0, h);
    sub(C, 0, 0, P2, 0, 0, C, 0, 0, h);
    add(C, 0, 0, P6, 0, 0, C, 0, 0, h);
    // C12 = P1 + P2
    add(P1, 0, 0, P2, 0, 0, C, 0, h, h);
    // C21 = P3 + P4
    add(P3, 0, 0, P4, 0, 0, C, h, 0, h);
    // C22 = P5 + P1 - P3 - P7
    add(P5, 0, 0, P1, 0, 0, C, h, h, h);
    sub(C, h, h, P3, 0, 0, C, h, h, h);
    sub(C, h, h, P7, 0, 0, C, h, h, h);

    // Release every h x h temporary.
    int **temps[] = {S1, S2, S3, S4, S5, S6, S7, S8, S9, S10,
                     P1, P2, P3, P4, P5, P6, P7};
    const int tempCount = (int)(sizeof temps / sizeof temps[0]);
    for (int t = 0; t < tempCount; t++)
    {
        deleteMat(temps[t], h);
    }
    return C;
}
// Matrix addition on n x n blocks:
//   C[xC+i][yC+j] = A[xA+i][yA+j] + B[xB+i][yB+j]   for 0 <= i, j < n.
// C may be the same matrix as A or B (each element is read before written).
// Returns C.
int ** add(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n)
{
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            C[xC + i][yC + j] = A[xA + i][yA + j] + B[xB + i][yB + j];
        }
    }
    return C;
}

// Matrix subtraction on n x n blocks:
//   C[xC+i][yC+j] = A[xA+i][yA+j] - B[xB+i][yB+j]   for 0 <= i, j < n.
// C may be the same matrix as A or B. Returns C.
int **sub(int **A, int xA, int yA, int **B, int xB, int yB, int **C, int xC, int yC, int n)
{
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            C[xC + i][yC + j] = A[xA + i][yA + j] - B[xB + i][yB + j];
        }
    }
    return C;
}

// Allocates an n x n int matrix as an array of n row pointers, with every
// element zero-initialised. Release it with deleteMat(M, n).
// Allocation failure surfaces as std::bad_alloc from operator new (a plain
// new-expression never returns NULL); any rows already allocated are freed
// before the exception propagates, so nothing leaks.
int **createMat(int n)
{
    int **M = new int *[n];
    int built = 0;
    try
    {
        for (; built < n; built++)
        {
            M[built] = new int[n]();  // "()" value-initialises each element to 0
        }
    }
    catch (...)
    {
        for (int j = 0; j < built; j++)
        {
            delete[] M[j];
        }
        delete[] M;
        throw;
    }
    return M;
}

// Frees a matrix created by createMat(n). A NULL matrix is ignored.
// The pointer is taken by value, so the caller's variable is not cleared and
// must not be used afterwards.
void deleteMat(int **Mat, int n)
{
    if (!Mat)
    {
        return;
    }
    for (int i = 0; i < n; i++)
    {
        delete[] Mat[i];
    }
    delete[] Mat;
}

// Builds a new n x n matrix from two n x n blocks:
//   type == 0 : S = A(xA, yA) + B(xB, yB)
//   otherwise : S = A(xA, yA) - B(xB, yB)
// The caller owns the result and must free it with deleteMat(S, n).
int **createS(int **A, int xA, int yA, int **B, int xB, int yB, int n, int type)
{
    int **S = createMat(n);
    if (0 == type)
    {
        add(A, xA, yA, B, xB, yB, S, 0, 0, n);
    }
    else
    {
        sub(A, xA, yA, B, xB, yB, S, 0, 0, n);
    }
    return S;
}