# 矩阵乘法（java实现）

矩阵的运算中经常要用到矩阵的乘法，通常情况下矩阵的乘法用计算机高级语言来实现有三种办法。

第一种是最普通的、最容易想到的办法，即用三个for循环来实现。它的时间复杂度为O(n^3)。具体实现如下：

public class Matrixmultiply {

public static void main(String[] args) {

int A[][] = new int[][] {
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{3, 3, 3, 3, 3, 3, 3, 3},
{4, 4, 4, 4, 4, 4, 4, 4},
{5, 5, 5, 5, 5, 5, 5, 5},
{6, 6, 6, 6, 6, 6, 6, 6},
{7, 7, 7, 7, 7, 7, 7, 7},
{8, 8, 8, 8, 8, 8, 8, 8},
};

int B[][] = new int[][] {
{8,8,8,8,8,8,8,8},
{7,7,7,7,7,7,7,7},
{6,6,6,6,6,6,6,6},
{5,5,5,5,5,5,5,5},
{4,4,4,4,4,4,4,4},
{3,3,3,3,3,3,3,3},
{2,2,2,2,2,2,2,2},
{1,1,1,1,1,1,1,1},
};
int[][] C = new int[8][8];
for(int i=0;i<8;i++){
for(int j=0;j<8;j++){
for(int k=0;k<8;k++){
C[i][j] =C[i][j] + A[i][k] * A[k][j];
}
}
}
System.out.println("普通矩阵乘法");
for(int i=0;i<8;i++){
for(int j=0;j<8;j++){
System.out.print(C[i][j] + " ");
}
System.out.println();
}
}
}

第二种是分治递归的办法：把A、B、C各划分为四个n/2×n/2的子矩阵，通过8次子矩阵的递归相乘和4次子矩阵相加得到C。可以用递归树的方法求得其时间复杂度仍为O(n^3)。其代码实现为：

public class Recursivematrixmultiply {
public static void displayMatrix(int matrix[][]) {
int n = matrix.length;
for (int i=0;i<n; i++) {
for (int j=0;j<n;j++) {
System.out.print(matrix[i][j] + " ");
}
System.out.println();
}
}

public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,
int destMatrix[][]) {
int n = destMatrix.length;
for (int i = startI; i < startI+n; i++) {
for (int j = startJ; j <startJ+n ; j++) {
destMatrix[i - startI][j - startJ] = srcMatrix[i][j];
}
}
}
public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,
int srcMatrix[][]) {
int n = srcMatrix.length;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
destMatrix[startI + i][startJ + j] = srcMatrix[i][j];
}
}
}
public static void MatrixAdd(int A[][],int B[][],int C[][]){
int n = A.length;
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
C[i][j] = A[i][j] + B[i][j];
}
}
}
public static void MatrixSub(int A[][], int B[][], int C[][]) {
int n = A.length;
for (int i=0;i<n;i++) {
for (int j=0;j<n;j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
}

public static int[][] squareMatrixMultiplyRecursive(int A[][], int B[][]) {
int n = A.length;
int C[][] = new int[n][n];
if (n == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int A11[][], A12[][], A21[][], A22[][];
int B11[][], B12[][], B21[][], B22[][];
int C11[][], C12[][], C21[][], C22[][];

A11 = new int[n/2][n/2];A12 = new int[n/2][n/2];A21 = new int[n/2][n/2];A22 = new int[n/2][n/2];
copyToMatrix(A, 0, 0, A11);
copyToMatrix(A, 0, n/2,  A12);
copyToMatrix(A, n/2, 0,  A21);
copyToMatrix(A, n/2, n/2, A22);

B11 = new int[n/2][n/2];B12 = new int[n/2][n/2];B21 = new int[n/2][n/2];B22 = new int[n/2][n/2];
copyToMatrix(B, 0, 0, B11);
copyToMatrix(B, 0, n/2, B12);
copyToMatrix(B, n/2, 0, B21);
copyToMatrix(B, n/2, n/2, B22);

C11 = new int[n/2][n/2];C12 = new int[n/2][n/2];C21 = new int[n/2][n/2];C22 = new int[n/2][n/2];
C11);
C12);
C21);
C22);
copyFromMatrix(C, 0, 0, C11);
copyFromMatrix(C, 0, n/2, C12);
copyFromMatrix(C, n/2, 0, C21);
copyFromMatrix(C, n/2, n/2, C22);
}

return C;
}
public static void main(String[] args) {
int A[][] = new int[][] {
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{3, 3, 3, 3, 3, 3, 3, 3},
{4, 4, 4, 4, 4, 4, 4, 4},
{5, 5, 5, 5, 5, 5, 5, 5},
{6, 6, 6, 6, 6, 6, 6, 6},
{7, 7, 7, 7, 7, 7, 7, 7},
{8, 8, 8, 8, 8, 8, 8, 8},
};

int B[][] = new int[][] {
{8,8,8,8,8,8,8,8},
{7,7,7,7,7,7,7,7},
{6,6,6,6,6,6,6,6},
{5,5,5,5,5,5,5,5},
{4,4,4,4,4,4,4,4},
{3,3,3,3,3,3,3,3},
{2,2,2,2,2,2,2,2},
{1,1,1,1,1,1,1,1},
};
System.out.println("递归矩阵乘法");
int C[][] = squareMatrixMultiplyRecursive(A,B);
displayMatrix(C);
}
} 

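第三种是Strassen算法：它同样把矩阵划分为四个n/2×n/2的子矩阵，但只需要7次子矩阵的递归相乘，时间复杂度为O(n^lg7)≈O(n^2.81)。Strassen算法细节如下：

S1=B12-B22，S2=A11+A12，S3=A21+A22，S4=B21-B11，S5=A11+A22，S6=B11+B22，S7=A12-A22，S8=B21+B22，S9=A11-A21，S10=B11+B12

P1=A11·S1，P2=S2·B22，P3=S3·B11，P4=A22·S4，P5=S5·S6，P6=S7·S8，P7=S9·S10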

C11=P5+P4-P2+P6

C12=P1+P2

C21=P3+P4

C22=P5+P1-P3-P7

  public class Strassen {

public static void displayMatrix(int matrix[][]) {
int n = matrix.length;
for (int i=0;i<n; i++) {
for (int j=0;j<n;j++) {
System.out.print(matrix[i][j] + " ");
}
System.out.println();
}
}

public static void copyToMatrix(int srcMatrix[][], int startI, int startJ,
int destMatrix[][]) {
int n = destMatrix.length;
for (int i = startI; i < startI+n; i++) {
for (int j = startJ; j <startJ+n ; j++) {
destMatrix[i - startI][j - startJ] = srcMatrix[i][j];
}
}
}
public static void copyFromMatrix(int destMatrix[][], int startI, int startJ,
int srcMatrix[][]) {
int n = srcMatrix.length;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
destMatrix[startI + i][startJ + j] = srcMatrix[i][j];
}
}
}
public static void MatrixAdd(int A[][],int B[][],int C[][]){
int n = A.length;
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
C[i][j] = A[i][j] + B[i][j];
}
}
}
public static void MatrixSub(int A[][], int B[][], int C[][]) {
int n = A.length;
for (int i=0;i<n;i++) {
for (int j=0;j<n;j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
}

public static int[][] strassenMatrixMultiplyRecursive(int[][] A,int[][] B){
int n = A.length;
int[][] C = new int[n][n];
if(n==1)
C[0][0] = A[0][0] * B[0][0];
else{
int[][] A11,A12,A21,A22;
int[][] B11,B12,B21,B22;
int[][] S1,S2,S3,S4,S5,S6,S7,S8,S9,S10;
int[][] P1,P2,P3,P4,P5,P6,P7;
int[][] C11,C12,C21,C22;

A11 = new int[n/2][n/2];
A12 = new int[n/2][n/2];
A21 = new int[n/2][n/2];
A22 = new int[n/2][n/2];
copyToMatrix(A,0,0,A11);
copyToMatrix(A,0,n/2,A12);
copyToMatrix(A,n/2,0,A21);
copyToMatrix(A,n/2,n/2,A22);

B11 = new int[n/2][n/2];
B12 = new int[n/2][n/2];
B21 = new int[n/2][n/2];
B22 = new int[n/2][n/2];
copyToMatrix(B,0,0,B11);
copyToMatrix(B,0,n/2,B12);
copyToMatrix(B,n/2,0,B21);
copyToMatrix(B,n/2,n/2,B22);

S1 = new int[n/2][n/2];
S2 = new int[n/2][n/2];
S3 = new int[n/2][n/2];
S4 = new int[n/2][n/2];
S5 = new int[n/2][n/2];
S6 = new int[n/2][n/2];
S7 = new int[n/2][n/2];
S8 = new int[n/2][n/2];
S9 = new int[n/2][n/2];
S10 = new int[n/2][n/2];
MatrixSub(B12,B22,S1);
MatrixSub(B21,B11,S4);
MatrixSub(A12,A22,S7);
MatrixSub(A11,A21,S9);

P1 = new int[n/2][n/2];P2 = new int[n/2][n/2];P3 = new int[n/2][n/2];P4 = new int[n/2][n/2];
P5 = new int[n/2][n/2];P6 = new int[n/2][n/2];P7 = new int[n/2][n/2];
P1 = strassenMatrixMultiplyRecursive(A11, S1);
P2 = strassenMatrixMultiplyRecursive(S2, B22);
P3 = strassenMatrixMultiplyRecursive(S3, B11);
P4 = strassenMatrixMultiplyRecursive(A22, S4);
P5 = strassenMatrixMultiplyRecursive(S5, S6);
P6 = strassenMatrixMultiplyRecursive(S7, S8);
P7 = strassenMatrixMultiplyRecursive(S9, S10);

C11 = new int[n/2][n/2];
C12 = new int[n/2][n/2];
C21 = new int[n/2][n/2];
C22 = new int[n/2][n/2];
int[][] temp = new int[n/2][n/2];
MatrixSub(temp,P2,temp);

MatrixSub(temp, P3, temp);
MatrixSub(temp, P7, C22);

copyFromMatrix(C,0,0,C11);
copyFromMatrix(C,0,n/2,C12);
copyFromMatrix(C,n/2,0,C21);
copyFromMatrix(C,n/2,n/2,C22);
}
return C;
}
public static void main(String[] args) {
int A[][] = new int[][] {
{1, 1, 1, 1, 1, 1, 1, 1},
{2, 2, 2, 2, 2, 2, 2, 2},
{3, 3, 3, 3, 3, 3, 3, 3},
{4, 4, 4, 4, 4, 4, 4, 4},
{5, 5, 5, 5, 5, 5, 5, 5},
{6, 6, 6, 6, 6, 6, 6, 6},
{7, 7, 7, 7, 7, 7, 7, 7},
{8, 8, 8, 8, 8, 8, 8, 8},
};

int B[][] = new int[][] {
{8,8,8,8,8,8,8,8},
{7,7,7,7,7,7,7,7},
{6,6,6,6,6,6,6,6},
{5,5,5,5,5,5,5,5},
{4,4,4,4,4,4,4,4},
{3,3,3,3,3,3,3,3},
{2,2,2,2,2,2,2,2},
{1,1,1,1,1,1,1,1},
};
System.out.println("Strassen 递归矩阵乘法");
int[][] C = new int[8][8];
C = strassenMatrixMultiplyRecursive(A,B);
displayMatrix(C);
}
}