矩阵乘法(java实现)

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Kigznlun/article/details/79942950

 矩阵的运算中经常要用到矩阵的乘法,通常情况下矩阵的乘法用计算机高级语言来实现有三种办法。

  第一种是最普通的、最容易想到的办法,即用三个for循环来实现。它的时间复杂度为O(n^3)。具体实现如下:

/**
 * Demonstrates the textbook triple-loop matrix multiplication, which performs
 * one multiply-add per (row, column, inner index) triple and therefore runs in
 * O(n^3) time for n x n inputs.
 */
public class Matrixmultiply {

    /**
     * Computes the matrix product {@code a x b} with three nested loops.
     *
     * @param a left operand with dimensions rows x inner
     * @param b right operand with dimensions inner x cols
     * @return a newly allocated rows x cols matrix holding the product
     * @throws IllegalArgumentException if a row of {@code a} does not have
     *         {@code b.length} entries, or if {@code b} is ragged
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = inner == 0 ? 0 : b[0].length;

        for (int[] row : a) {
            if (row.length != inner) {
                throw new IllegalArgumentException(
                        "Row length " + row.length + " of the left matrix does not match "
                                + inner + " rows of the right matrix");
            }
        }
        for (int[] row : b) {
            if (row.length != cols) {
                throw new IllegalArgumentException("The right matrix is ragged");
            }
        }

        int[][] product = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    // The second factor must come from b; reading it from a
                    // would compute a x a instead of a x b.
                    sum += a[i][k] * b[k][j];
                }
                product[i][j] = sum;
            }
        }
        return product;
    }

    /**
     * Multiplies two fixed 8 x 8 sample matrices and prints the product, one
     * row per line with every entry followed by a single space.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        int A[][] = new int[][] {
            {1, 1, 1, 1, 1, 1, 1, 1},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {8, 8, 8, 8, 8, 8, 8, 8},
        };

        int B[][] = new int[][] {
            {8, 8, 8, 8, 8, 8, 8, 8},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {1, 1, 1, 1, 1, 1, 1, 1},
        };

        int[][] C = multiply(A, B);

        System.out.println("普通矩阵乘法");
        for (int[] row : C) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }
}

你最初可能觉得任何矩阵乘法都要花费这么多时间,因为矩阵乘法的自然定义就是如此。但这是错误的,如果用分治的思想,看看我们能不能将其时间复杂度简化。

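为简单起见,用分治法时我们假定三个矩阵均为n*n矩阵,其中n为2的幂。将A、B、C各自划分为4个n/2*n/2的子矩阵:

$$A=\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix},\quad B=\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix},\quad C=\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}$$

于是C=A·B可以写成:

$$C_{11}=A_{11}B_{11}+A_{12}B_{21},\quad C_{12}=A_{11}B_{12}+A_{12}B_{22}$$

$$C_{21}=A_{21}B_{11}+A_{22}B_{21},\quad C_{22}=A_{21}B_{12}+A_{22}B_{22}$$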

在计算C11等矩阵时,我们又遇到了两个矩阵相乘,所以可以用递归来解决。这个程序的运行时间为T(n)=8T(n/2)+O(n^2)。因为每次递归都需要做八次n/2*n/2矩阵的乘法,因此为8T(n/2)。同时我们还要计算四次矩阵加法,每个矩阵包含n^2/4个元素,因此为O(n^2)。所以它的递归式为:

$$T(n)=\begin{cases}\Theta(1) & n=1\\ 8T(n/2)+\Theta(n^2) & n>1\end{cases}$$

  可以用递归树的方法求得其时间复杂度仍为O(n^3)。其代码实现为

/**
 * Divide-and-conquer multiplication of square integer matrices.
 *
 * <p>Each operand is split into four quadrants, the eight quadrant products
 * are computed recursively and then summed pairwise, giving the recurrence
 * T(n) = 8T(n/2) + O(n^2).
 */
public class Recursivematrixmultiply {

    /**
     * Prints a matrix to standard output, one row per line, with every entry
     * followed by a single space.
     *
     * @param matrix the matrix to print
     */
    public static void displayMatrix(int matrix[][]) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Fills {@code destMatrix} with the block of {@code srcMatrix} whose
     * top-left corner is at ({@code startI}, {@code startJ}). The block is
     * {@code destMatrix.length} rows by {@code destMatrix.length} columns.
     *
     * @param srcMatrix  matrix to read from
     * @param startI     first source row
     * @param startJ     first source column
     * @param destMatrix square matrix that receives the block
     */
    public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,
            int destMatrix[][]) {
        int n = destMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[startI + i], startJ, destMatrix[i], 0, n);
        }
    }

    /**
     * Writes all of {@code srcMatrix} into {@code destMatrix} so that its
     * top-left entry lands at ({@code startI}, {@code startJ}). The source is
     * treated as {@code srcMatrix.length} rows by {@code srcMatrix.length} columns.
     *
     * @param destMatrix matrix to write into
     * @param startI     first destination row
     * @param startJ     first destination column
     * @param srcMatrix  square matrix to copy
     */
    public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,
            int srcMatrix[][]) {
        int n = srcMatrix.length;
        for (int i = 0; i < n; i++) {
            System.arraycopy(srcMatrix[i], 0, destMatrix[startI + i], startJ, n);
        }
    }

    /**
     * Stores the element-wise sum {@code A + B} in {@code C}. All three are
     * treated as {@code A.length} x {@code A.length} matrices.
     *
     * @param A first addend
     * @param B second addend
     * @param C receives the sum
     */
    public static void MatrixAdd(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
    }

    /**
     * Stores the element-wise difference {@code A - B} in {@code C}. All three
     * are treated as {@code A.length} x {@code A.length} matrices.
     *
     * @param A minuend
     * @param B subtrahend
     * @param C receives the difference
     */
    public static void MatrixSub(int A[][], int B[][], int C[][]) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
    }

    /**
     * Multiplies two square matrices of the same size by recursive quadrant
     * decomposition.
     *
     * <p>Halving only works when the size is a power of two, so any other size
     * is zero-padded up to the next power of two and the product is cropped
     * back afterwards. Zero padding does not change the top-left n x n part of
     * the product.
     *
     * @param A left operand, n x n
     * @param B right operand, n x n
     * @return a newly allocated n x n matrix holding {@code A x B}
     * @throws IllegalArgumentException if either argument is not an n x n
     *         matrix with n equal to {@code A.length}
     */
    public static int[][] squareMatrixMultiplyRecursive(int A[][], int B[][]) {
        int n = A.length;
        requireSquare(A, n, "A");
        requireSquare(B, n, "B");

        if (n == 0) {
            // Halving an empty matrix never reaches the 1 x 1 base case.
            return new int[0][0];
        }

        int padded = nextPowerOfTwo(n);
        if (padded == n) {
            return multiplyPowerOfTwo(A, B);
        }

        int[][] paddedProduct = multiplyPowerOfTwo(padTo(A, padded), padTo(B, padded));
        int[][] C = new int[n][n];
        copyToMatrix(paddedProduct, 0, 0, C);
        return C;
    }

    /** Recursive core; requires both operands to be n x n with n a power of two. */
    private static int[][] multiplyPowerOfTwo(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        if (n == 1) {
            C[0][0] = A[0][0] * B[0][0];
            return C;
        }

        int half = n / 2;
        int[][] A11 = quadrant(A, 0, 0, half);
        int[][] A12 = quadrant(A, 0, half, half);
        int[][] A21 = quadrant(A, half, 0, half);
        int[][] A22 = quadrant(A, half, half, half);

        int[][] B11 = quadrant(B, 0, 0, half);
        int[][] B12 = quadrant(B, 0, half, half);
        int[][] B21 = quadrant(B, half, 0, half);
        int[][] B22 = quadrant(B, half, half, half);

        int[][] C11 = new int[half][half];
        int[][] C12 = new int[half][half];
        int[][] C21 = new int[half][half];
        int[][] C22 = new int[half][half];

        // Each quadrant of C is the sum of two quadrant products.
        MatrixAdd(multiplyPowerOfTwo(A11, B11), multiplyPowerOfTwo(A12, B21), C11);
        MatrixAdd(multiplyPowerOfTwo(A11, B12), multiplyPowerOfTwo(A12, B22), C12);
        MatrixAdd(multiplyPowerOfTwo(A21, B11), multiplyPowerOfTwo(A22, B21), C21);
        MatrixAdd(multiplyPowerOfTwo(A21, B12), multiplyPowerOfTwo(A22, B22), C22);

        copyFromMatrix(C, 0, 0, C11);
        copyFromMatrix(C, 0, half, C12);
        copyFromMatrix(C, half, 0, C21);
        copyFromMatrix(C, half, half, C22);
        return C;
    }

    /** Returns a copy of the size x size block of matrix starting at (row, col). */
    private static int[][] quadrant(int[][] matrix, int row, int col, int size) {
        int[][] block = new int[size][size];
        copyToMatrix(matrix, row, col, block);
        return block;
    }

    /** Returns a size x size copy of the square matrix, zero-filled beyond its original extent. */
    private static int[][] padTo(int[][] matrix, int size) {
        int[][] padded = new int[size][size];
        copyFromMatrix(padded, 0, 0, matrix);
        return padded;
    }

    /** Returns the smallest power of two that is greater than or equal to n (for n >= 1). */
    private static int nextPowerOfTwo(int n) {
        int size = 1;
        while (size < n) {
            size <<= 1;
        }
        return size;
    }

    /** Throws IllegalArgumentException unless matrix has exactly n rows of n entries each. */
    private static void requireSquare(int[][] matrix, int n, String label) {
        if (matrix.length != n) {
            throw new IllegalArgumentException(
                    "Matrix " + label + " must have " + n + " rows but has " + matrix.length);
        }
        for (int[] row : matrix) {
            if (row.length != n) {
                throw new IllegalArgumentException(
                        "Matrix " + label + " must be " + n + " x " + n
                                + " but has a row of length " + row.length);
            }
        }
    }

    /**
     * Multiplies two fixed 8 x 8 sample matrices recursively and prints the product.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int A[][] = new int[][] {
            {1, 1, 1, 1, 1, 1, 1, 1},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {8, 8, 8, 8, 8, 8, 8, 8},
        };

        int B[][] = new int[][] {
            {8, 8, 8, 8, 8, 8, 8, 8},
            {7, 7, 7, 7, 7, 7, 7, 7},
            {6, 6, 6, 6, 6, 6, 6, 6},
            {5, 5, 5, 5, 5, 5, 5, 5},
            {4, 4, 4, 4, 4, 4, 4, 4},
            {3, 3, 3, 3, 3, 3, 3, 3},
            {2, 2, 2, 2, 2, 2, 2, 2},
            {1, 1, 1, 1, 1, 1, 1, 1},
        };
        System.out.println("递归矩阵乘法");
        int C[][] = squareMatrixMultiplyRecursive(A, B);
        displayMatrix(C);
    }
}

下面我们用Strassen方法来求矩阵乘法,它将递归的8次乘法简化为7次。其代价是增加了几次额外的加法,但这对整个算法影响不大。Strassen算法运行时间的递归式为

$$T(n)=\begin{cases}\Theta(1) & n=1\\ 7T(n/2)+\Theta(n^2) & n>1\end{cases}$$

 

可以求出其解为

$$T(n)=\Theta(n^{\lg 7})\approx O(n^{2.81})$$

strassen算法细节如下

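首先创建10个n/2*n/2的矩阵:

$$S_1=B_{12}-B_{22},\quad S_2=A_{11}+A_{12},\quad S_3=A_{21}+A_{22},\quad S_4=B_{21}-B_{11},\quad S_5=A_{11}+A_{22}$$

$$S_6=B_{11}+B_{22},\quad S_7=A_{12}-A_{22},\quad S_8=B_{21}+B_{22},\quad S_9=A_{11}-A_{21},\quad S_{10}=B_{11}+B_{12}$$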

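递归地计算七次矩阵乘法:

$$P_1=A_{11}S_1,\quad P_2=S_2B_{22},\quad P_3=S_3B_{11},\quad P_4=A_{22}S_4$$

$$P_5=S_5S_6,\quad P_6=S_7S_8,\quad P_7=S_9S_{10}$$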

最后对P矩阵进行加减法运算

C11=P5+P4-P2+P6

C12=P1+P2

C21=P3+P4

C22=P5+P1-P3-P7

其代码实现为

  public class Strassen {
 
    public static void displayMatrix(int matrix[][]) { 
    int n = matrix.length;
        for (int i=0;i<n; i++) {  
            for (int j=0;j<n;j++) {  
                System.out.print(matrix[i][j] + " ");  
            }  
            System.out.println();  
        }  
    }  

    public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,   
            int destMatrix[][]) { 
    int n = destMatrix.length;
        for (int i = startI; i < startI+n; i++) {  
            for (int j = startJ; j <startJ+n ; j++) {  
                destMatrix[i - startI][j - startJ] = srcMatrix[i][j];   
            }  
        }  
    } 
    public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,  
            int srcMatrix[][]) {  
    int n = srcMatrix.length;
        for (int i = 0; i < n; i++) {  
            for (int j = 0; j < n; j++) {  
                destMatrix[startI + i][startJ + j] = srcMatrix[i][j];   
            }  
        }  
    } 
    public static void MatrixAdd(int A[][],int B[][],int C[][]){
        int n = A.length;
    for(int i=0;i<n;i++){
    for(int j=0;j<n;j++){
    C[i][j] = A[i][j] + B[i][j];
    }
    } 
    }
    public static void MatrixSub(int A[][], int B[][], int C[][]) {  
    int n = A.length;
        for (int i=0;i<n;i++) {  
            for (int j=0;j<n;j++) {  
                C[i][j] = A[i][j] - B[i][j];  
            }  
        }  
    } 

    public static int[][] strassenMatrixMultiplyRecursive(int[][] A,int[][] B){
 int n = A.length;
 int[][] C = new int[n][n];
 if(n==1)
   C[0][0] = A[0][0] * B[0][0];
 else{
 int[][] A11,A12,A21,A22; 
 int[][] B11,B12,B21,B22;
 int[][] S1,S2,S3,S4,S5,S6,S7,S8,S9,S10;
 int[][] P1,P2,P3,P4,P5,P6,P7;
 int[][] C11,C12,C21,C22;
 
 A11 = new int[n/2][n/2];
 A12 = new int[n/2][n/2];
 A21 = new int[n/2][n/2];
 A22 = new int[n/2][n/2];
 copyToMatrix(A,0,0,A11);
 copyToMatrix(A,0,n/2,A12);
 copyToMatrix(A,n/2,0,A21);
 copyToMatrix(A,n/2,n/2,A22);


 B11 = new int[n/2][n/2];
 B12 = new int[n/2][n/2];
 B21 = new int[n/2][n/2];
 B22 = new int[n/2][n/2];
 copyToMatrix(B,0,0,B11);
 copyToMatrix(B,0,n/2,B12);
 copyToMatrix(B,n/2,0,B21);
 copyToMatrix(B,n/2,n/2,B22);
 
 S1 = new int[n/2][n/2];
 S2 = new int[n/2][n/2];
 S3 = new int[n/2][n/2];
 S4 = new int[n/2][n/2];
 S5 = new int[n/2][n/2];
 S6 = new int[n/2][n/2];
 S7 = new int[n/2][n/2];
 S8 = new int[n/2][n/2];
 S9 = new int[n/2][n/2];
 S10 = new int[n/2][n/2];
 MatrixSub(B12,B22,S1);
 MatrixAdd(A11,A12,S2);
 MatrixAdd(A21,A22,S3);
 MatrixSub(B21,B11,S4);
 MatrixAdd(A11,A22,S5);
 MatrixAdd(B11,B22,S6);
 MatrixSub(A12,A22,S7);
 MatrixAdd(B21,B22,S8);
 MatrixSub(A11,A21,S9);
 MatrixAdd(B11,B12,S10);
 
 P1 = new int[n/2][n/2];P2 = new int[n/2][n/2];P3 = new int[n/2][n/2];P4 = new int[n/2][n/2];  
          P5 = new int[n/2][n/2];P6 = new int[n/2][n/2];P7 = new int[n/2][n/2];
          P1 = strassenMatrixMultiplyRecursive(A11, S1);  
          P2 = strassenMatrixMultiplyRecursive(S2, B22);  
          P3 = strassenMatrixMultiplyRecursive(S3, B11);  
          P4 = strassenMatrixMultiplyRecursive(A22, S4);  
          P5 = strassenMatrixMultiplyRecursive(S5, S6);  
          P6 = strassenMatrixMultiplyRecursive(S7, S8);  
          P7 = strassenMatrixMultiplyRecursive(S9, S10);
          
          C11 = new int[n/2][n/2];
          C12 = new int[n/2][n/2];
          C21 = new int[n/2][n/2];
          C22 = new int[n/2][n/2];
          int[][] temp = new int[n/2][n/2];
          MatrixAdd(P5,P4,temp);
          MatrixSub(temp,P2,temp);
          MatrixAdd(temp,P6,C11);
          
          MatrixAdd(P1, P2, C12);  
          MatrixAdd(P3, P4, C21);  
            
          MatrixAdd(P5, P1, temp);  
          MatrixSub(temp, P3, temp);  
          MatrixSub(temp, P7, C22);
          
          copyFromMatrix(C,0,0,C11);
          copyFromMatrix(C,0,n/2,C12);
          copyFromMatrix(C,n/2,0,C21);
          copyFromMatrix(C,n/2,n/2,C22);   
 }
 return C;  
  }
    public static void main(String[] args) { 
         int A[][] = new int[][] {  
            {1, 1, 1, 1, 1, 1, 1, 1},  
            {2, 2, 2, 2, 2, 2, 2, 2},  
            {3, 3, 3, 3, 3, 3, 3, 3},  
            {4, 4, 4, 4, 4, 4, 4, 4},  
            {5, 5, 5, 5, 5, 5, 5, 5},  
            {6, 6, 6, 6, 6, 6, 6, 6},  
            {7, 7, 7, 7, 7, 7, 7, 7},  
            {8, 8, 8, 8, 8, 8, 8, 8},  
        };  
          
         int B[][] = new int[][] {  
            {8,8,8,8,8,8,8,8},  
            {7,7,7,7,7,7,7,7},  
            {6,6,6,6,6,6,6,6},  
            {5,5,5,5,5,5,5,5},  
            {4,4,4,4,4,4,4,4},  
            {3,3,3,3,3,3,3,3},  
            {2,2,2,2,2,2,2,2},  
            {1,1,1,1,1,1,1,1},  
        };
        System.out.println("Strassen 递归矩阵乘法");  
        int[][] C = new int[8][8];
        C = strassenMatrixMultiplyRecursive(A,B);  
        displayMatrix(C);
    }
}

 

展开阅读全文

没有更多推荐了,返回首页