矩阵的运算中经常要用到矩阵的乘法,通常情况下矩阵的乘法用计算机高级语言来实现有三种办法。
第一种是最普通的、最容易想到的办法,即用三个for循环来实现。它的时间复杂度为O(n^3)。具体实现如下:
/**
 * Schoolbook matrix multiplication: three nested loops, which costs O(n^3)
 * scalar multiplications for two n x n operands.
 */
public class Matrixmultiply {

    /**
     * Multiplies two integer matrices with the classic triple loop.
     *
     * @param left  the left operand, {@code m x n}
     * @param right the right operand, {@code n x p}
     * @return a newly allocated {@code m x p} matrix holding {@code left * right}
     * @throws IllegalArgumentException if a row of {@code left} does not have as many
     *         columns as {@code right} has rows, or if {@code right} is ragged
     */
    public static int[][] multiply(int[][] left, int[][] right) {
        int rows = left.length;
        int inner = right.length;
        int cols = inner == 0 ? 0 : right[0].length;

        for (int[] row : left) {
            if (row.length != inner) {
                throw new IllegalArgumentException(
                        "left row has " + row.length + " columns but right has " + inner + " rows");
            }
        }
        for (int[] row : right) {
            if (row.length != cols) {
                throw new IllegalArgumentException("right matrix rows must all have the same length");
            }
        }

        int[][] product = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    // The second factor must come from the right operand.
                    sum += left[i][k] * right[k][j];
                }
                product[i][j] = sum;
            }
        }
        return product;
    }

    /**
     * Multiplies the two 8 x 8 sample matrices and prints the product, one row per line.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[][] A = new int[][] {
            {1, 1, 1, 1, 1, 1, 1, 1},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {8, 8, 8, 8, 8, 8, 8, 8},
        };
        int[][] B = new int[][] {
            {8, 8, 8, 8, 8, 8, 8, 8},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {1, 1, 1, 1, 1, 1, 1, 1},
        };

        // Each column of B sums to 36, so every entry in row i of the product is 36 * (i + 1).
        int[][] C = multiply(A, B);

        System.out.println("普通矩阵乘法");
        for (int[] row : C) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }
}
你最初可能觉得任何矩阵乘法都要花费这么多时间,因为矩阵乘法的自然定义就是如此。但这是错误的,如果用分治的思想,看看我们能不能将其时间复杂度简化。
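为简单起见,用分治法时我们假定三个矩阵均为n*n矩阵,其中n为2的幂。将A、B、C各划分为4个(n/2)*(n/2)的子矩阵:
A = [A11 A12; A21 A22],B = [B11 B12; B21 B22],C = [C11 C12; C21 C22]
于是C=A*B可以写成:
C11=A11*B11+A12*B21
C12=A11*B12+A12*B22
C21=A21*B11+A22*B21
C22=A21*B12+A22*B22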
在计算C11等矩阵时,我们又遇到了两个矩阵相乘,所以可以用递归来解决。这个程序的时间复杂度为T(n)=8T(n/2)+O(n^2)。因为每次递归都需要做八次乘法,因此为8T(n/2)。同时我们还要计算四次矩阵加法,每个矩阵包含n^2/4个元素,因此为O(n^2)。所以它的递归式为:
T(n)=O(1)(当n=1时);T(n)=8T(n/2)+O(n^2)(当n>1时)
可以用递归树的方法求得其时间复杂度仍为O(n^3)。其代码实现为
/**
 * Divide-and-conquer multiplication of square integer matrices.
 *
 * <p>Each operand is split into four quadrants and the product is assembled from
 * eight recursive quadrant products plus four quadrant additions.
 */
public class Recursivematrixmultiply {

    /**
     * Prints a matrix to standard output, one row per line, each value followed by a space.
     *
     * @param matrix the matrix to print
     */
    public static void displayMatrix(int matrix[][]) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Fills {@code destMatrix} with the square window of {@code srcMatrix} whose top-left
     * corner is {@code (startI, startJ)}. The window side is {@code destMatrix.length}.
     *
     * @param srcMatrix  matrix to read from
     * @param startI     first source row of the window
     * @param startJ     first source column of the window
     * @param destMatrix square matrix that receives the window
     */
    public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,
            int destMatrix[][]) {
        int n = destMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[startI + i], startJ, destMatrix[i], 0, n);
        }
    }

    /**
     * Writes the whole of square {@code srcMatrix} into {@code destMatrix}, placing its
     * top-left element at {@code (startI, startJ)}.
     *
     * @param destMatrix matrix to write into
     * @param startI     first destination row
     * @param startJ     first destination column
     * @param srcMatrix  square matrix to copy
     */
    public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,
            int srcMatrix[][]) {
        int n = srcMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[i], 0, destMatrix[startI + i], startJ, n);
        }
    }

    /**
     * Stores the element-wise sum {@code A + B} in {@code C}.
     *
     * @param A first square operand
     * @param B second operand, same size as {@code A}
     * @param C receives the sum; may be the same array as an operand
     */
    public static void MatrixAdd(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
    }

    /**
     * Stores the element-wise difference {@code A - B} in {@code C}.
     *
     * @param A square minuend
     * @param B subtrahend, same size as {@code A}
     * @param C receives the difference; may be the same array as an operand
     */
    public static void MatrixSub(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
    }

    /**
     * Multiplies two square matrices of equal size by recursive quadrant decomposition.
     *
     * <p>The recursion halves the side length at each level, so it only works directly
     * on power-of-two sizes. Other sizes are zero-padded up to the next power of two,
     * multiplied, and cropped back; the zero padding does not affect the retained entries.
     *
     * @param A left operand, {@code n x n}
     * @param B right operand, {@code n x n}
     * @return a newly allocated {@code n x n} matrix holding {@code A * B}
     * @throws IllegalArgumentException if either operand is not {@code n x n}
     */
    public static int[][] squareMatrixMultiplyRecursive(int A[][], int B[][]) {
        int n = A.length;
        requireSquare(A, n, "A");
        requireSquare(B, n, "B");

        // Without this guard an empty operand would halve to itself and recurse forever.
        if (n == 0) {
            return new int[0][0];
        }
        if ((n & (n - 1)) == 0) {
            return multiplyPowerOfTwo(A, B);
        }

        int paddedSize = Integer.highestOneBit(n) << 1;
        int[][] paddedProduct = multiplyPowerOfTwo(padTo(A, paddedSize), padTo(B, paddedSize));
        int[][] C = new int[n][n];
        copyToMatrix(paddedProduct, 0, 0, C);
        return C;
    }

    /** Throws unless {@code matrix} has exactly {@code n} rows of exactly {@code n} columns. */
    private static void requireSquare(int[][] matrix, int n, String name) {
        if (matrix.length != n) {
            throw new IllegalArgumentException(
                    name + " must have " + n + " rows but has " + matrix.length);
        }
        for (int[] row : matrix) {
            if (row.length != n) {
                throw new IllegalArgumentException(
                        name + " must be " + n + " x " + n + " but has a row of length " + row.length);
            }
        }
    }

    /** Returns a {@code size x size} copy of {@code matrix} with zeros beyond its original extent. */
    private static int[][] padTo(int[][] matrix, int size) {
        int[][] padded = new int[size][size];
        copyFromMatrix(padded, 0, 0, matrix);
        return padded;
    }

    /** Returns a copy of the {@code half x half} quadrant of {@code src} starting at {@code (row, col)}. */
    private static int[][] quadrant(int[][] src, int row, int col, int half) {
        int[][] block = new int[half][half];
        copyToMatrix(src, row, col, block);
        return block;
    }

    /** Recursive core; both operands must be square with the same power-of-two side length. */
    private static int[][] multiplyPowerOfTwo(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        if (n == 1) {
            C[0][0] = A[0][0] * B[0][0];
            return C;
        }

        int half = n / 2;
        int[][] A11 = quadrant(A, 0, 0, half);
        int[][] A12 = quadrant(A, 0, half, half);
        int[][] A21 = quadrant(A, half, 0, half);
        int[][] A22 = quadrant(A, half, half, half);
        int[][] B11 = quadrant(B, 0, 0, half);
        int[][] B12 = quadrant(B, 0, half, half);
        int[][] B21 = quadrant(B, half, 0, half);
        int[][] B22 = quadrant(B, half, half, half);

        int[][] C11 = new int[half][half];
        int[][] C12 = new int[half][half];
        int[][] C21 = new int[half][half];
        int[][] C22 = new int[half][half];

        // Each result quadrant is the sum of two quadrant products: eight recursive calls in total.
        MatrixAdd(multiplyPowerOfTwo(A11, B11), multiplyPowerOfTwo(A12, B21), C11);
        MatrixAdd(multiplyPowerOfTwo(A11, B12), multiplyPowerOfTwo(A12, B22), C12);
        MatrixAdd(multiplyPowerOfTwo(A21, B11), multiplyPowerOfTwo(A22, B21), C21);
        MatrixAdd(multiplyPowerOfTwo(A21, B12), multiplyPowerOfTwo(A22, B22), C22);

        copyFromMatrix(C, 0, 0, C11);
        copyFromMatrix(C, 0, half, C12);
        copyFromMatrix(C, half, 0, C21);
        copyFromMatrix(C, half, half, C22);
        return C;
    }

    /**
     * Multiplies the two 8 x 8 sample matrices recursively and prints the product.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[][] A = new int[][] {
            {1, 1, 1, 1, 1, 1, 1, 1},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {8, 8, 8, 8, 8, 8, 8, 8},
        };
        int[][] B = new int[][] {
            {8, 8, 8, 8, 8, 8, 8, 8},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {1, 1, 1, 1, 1, 1, 1, 1},
        };
        System.out.println("递归矩阵乘法");
        int[][] C = squareMatrixMultiplyRecursive(A, B);
        displayMatrix(C);
    }
}
下面我们用Strassen方法来求矩阵乘法,它将递归的8次乘法简化为7次。其代价是增加了几次额外的加法,但这对整个算法影响不大。Strassen算法运行时间的递归式为:
T(n)=O(1)(当n=1时);T(n)=7T(n/2)+O(n^2)(当n>1时)
可以求出其解为:
T(n)=O(n^lg7)≈O(n^2.81)
strassen算法细节如下
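首先创建10个矩阵:
S1=B12-B22
S2=A11+A12
S3=A21+A22
S4=B21-B11
S5=A11+A22
S6=B11+B22
S7=A12-A22
S8=B21+B22
S9=A11-A21
S10=B11+B12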
递归地计算七次矩阵乘法:
P1=A11*S1
P2=S2*B22
P3=S3*B11
P4=A22*S4
P5=S5*S6
P6=S7*S8
P7=S9*S10
最后对P矩阵进行加减法运算
C11=P5+P4-P2+P6
C12=P1+P2
C21=P3+P4
C22=P5+P1-P3-P7
其代码实现为
/**
 * Strassen multiplication of square integer matrices.
 *
 * <p>Each level of the recursion forms ten quadrant sums/differences (S1..S10), performs
 * seven recursive quadrant products (P1..P7), and combines them into the four quadrants
 * of the result.
 */
public class Strassen {

    /**
     * Prints a matrix to standard output, one row per line, each value followed by a space.
     *
     * @param matrix the matrix to print
     */
    public static void displayMatrix(int matrix[][]) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Fills {@code destMatrix} with the square window of {@code srcMatrix} whose top-left
     * corner is {@code (startI, startJ)}. The window side is {@code destMatrix.length}.
     *
     * @param srcMatrix  matrix to read from
     * @param startI     first source row of the window
     * @param startJ     first source column of the window
     * @param destMatrix square matrix that receives the window
     */
    public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,
            int destMatrix[][]) {
        int n = destMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[startI + i], startJ, destMatrix[i], 0, n);
        }
    }

    /**
     * Writes the whole of square {@code srcMatrix} into {@code destMatrix}, placing its
     * top-left element at {@code (startI, startJ)}.
     *
     * @param destMatrix matrix to write into
     * @param startI     first destination row
     * @param startJ     first destination column
     * @param srcMatrix  square matrix to copy
     */
    public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,
            int srcMatrix[][]) {
        int n = srcMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[i], 0, destMatrix[startI + i], startJ, n);
        }
    }

    /**
     * Stores the element-wise sum {@code A + B} in {@code C}.
     *
     * @param A first square operand
     * @param B second operand, same size as {@code A}
     * @param C receives the sum; may be the same array as an operand
     */
    public static void MatrixAdd(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
    }

    /**
     * Stores the element-wise difference {@code A - B} in {@code C}.
     *
     * @param A square minuend
     * @param B subtrahend, same size as {@code A}
     * @param C receives the difference; may be the same array as an operand
     */
    public static void MatrixSub(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
    }

    /**
     * Multiplies two square matrices of equal size with Strassen's seven-product scheme.
     *
     * <p>The recursion halves the side length at each level, so it only works directly
     * on power-of-two sizes. Other sizes are zero-padded up to the next power of two,
     * multiplied, and cropped back; the zero padding does not affect the retained entries.
     *
     * @param A left operand, {@code n x n}
     * @param B right operand, {@code n x n}
     * @return a newly allocated {@code n x n} matrix holding {@code A * B}
     * @throws IllegalArgumentException if either operand is not {@code n x n}
     */
    public static int[][] strassenMatrixMultiplyRecursive(int[][] A, int[][] B) {
        int n = A.length;
        requireSquare(A, n, "A");
        requireSquare(B, n, "B");

        // Without this guard an empty operand would halve to itself and recurse forever.
        if (n == 0) {
            return new int[0][0];
        }
        if ((n & (n - 1)) == 0) {
            return multiplyPowerOfTwo(A, B);
        }

        int paddedSize = Integer.highestOneBit(n) << 1;
        int[][] paddedProduct = multiplyPowerOfTwo(padTo(A, paddedSize), padTo(B, paddedSize));
        int[][] C = new int[n][n];
        copyToMatrix(paddedProduct, 0, 0, C);
        return C;
    }

    /** Throws unless {@code matrix} has exactly {@code n} rows of exactly {@code n} columns. */
    private static void requireSquare(int[][] matrix, int n, String name) {
        if (matrix.length != n) {
            throw new IllegalArgumentException(
                    name + " must have " + n + " rows but has " + matrix.length);
        }
        for (int[] row : matrix) {
            if (row.length != n) {
                throw new IllegalArgumentException(
                        name + " must be " + n + " x " + n + " but has a row of length " + row.length);
            }
        }
    }

    /** Returns a {@code size x size} copy of {@code matrix} with zeros beyond its original extent. */
    private static int[][] padTo(int[][] matrix, int size) {
        int[][] padded = new int[size][size];
        copyFromMatrix(padded, 0, 0, matrix);
        return padded;
    }

    /** Returns a copy of the {@code half x half} quadrant of {@code src} starting at {@code (row, col)}. */
    private static int[][] quadrant(int[][] src, int row, int col, int half) {
        int[][] block = new int[half][half];
        copyToMatrix(src, row, col, block);
        return block;
    }

    /** Returns a new matrix holding {@code x + y}. */
    private static int[][] add(int[][] x, int[][] y) {
        int[][] sum = new int[x.length][x.length];
        MatrixAdd(x, y, sum);
        return sum;
    }

    /** Returns a new matrix holding {@code x - y}. */
    private static int[][] subtract(int[][] x, int[][] y) {
        int[][] difference = new int[x.length][x.length];
        MatrixSub(x, y, difference);
        return difference;
    }

    /** Recursive core; both operands must be square with the same power-of-two side length. */
    private static int[][] multiplyPowerOfTwo(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        if (n == 1) {
            C[0][0] = A[0][0] * B[0][0];
            return C;
        }

        int half = n / 2;
        int[][] A11 = quadrant(A, 0, 0, half);
        int[][] A12 = quadrant(A, 0, half, half);
        int[][] A21 = quadrant(A, half, 0, half);
        int[][] A22 = quadrant(A, half, half, half);
        int[][] B11 = quadrant(B, 0, 0, half);
        int[][] B12 = quadrant(B, 0, half, half);
        int[][] B21 = quadrant(B, half, 0, half);
        int[][] B22 = quadrant(B, half, half, half);

        // The ten quadrant sums and differences that feed the seven products.
        int[][] S1 = subtract(B12, B22);
        int[][] S2 = add(A11, A12);
        int[][] S3 = add(A21, A22);
        int[][] S4 = subtract(B21, B11);
        int[][] S5 = add(A11, A22);
        int[][] S6 = add(B11, B22);
        int[][] S7 = subtract(A12, A22);
        int[][] S8 = add(B21, B22);
        int[][] S9 = subtract(A11, A21);
        int[][] S10 = add(B11, B12);

        // Seven recursive products instead of the eight a plain quadrant split needs.
        int[][] P1 = multiplyPowerOfTwo(A11, S1);
        int[][] P2 = multiplyPowerOfTwo(S2, B22);
        int[][] P3 = multiplyPowerOfTwo(S3, B11);
        int[][] P4 = multiplyPowerOfTwo(A22, S4);
        int[][] P5 = multiplyPowerOfTwo(S5, S6);
        int[][] P6 = multiplyPowerOfTwo(S7, S8);
        int[][] P7 = multiplyPowerOfTwo(S9, S10);

        // Combine the products straight into the four quadrants of the result.
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                C[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];               // C11
                C[i][j + half] = P1[i][j] + P2[i][j];                              // C12
                C[i + half][j] = P3[i][j] + P4[i][j];                              // C21
                C[i + half][j + half] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j]; // C22
            }
        }
        return C;
    }

    /**
     * Multiplies the two 8 x 8 sample matrices with Strassen's algorithm and prints the product.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[][] A = new int[][] {
            {1, 1, 1, 1, 1, 1, 1, 1},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {8, 8, 8, 8, 8, 8, 8, 8},
        };
        int[][] B = new int[][] {
            {8, 8, 8, 8, 8, 8, 8, 8},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {1, 1, 1, 1, 1, 1, 1, 1},
        };
        System.out.println("Strassen 递归矩阵乘法");
        int[][] C = strassenMatrixMultiplyRecursive(A, B);
        displayMatrix(C);
    }
}