最简单的矩阵乘法实现
/**
 * Computes the product of two square matrices with the textbook triple loop.
 *
 * <p>Entry (row, col) of the result is the dot product of row {@code row} of
 * {@code arr1} with column {@code col} of {@code arr2}. Runs in Θ(n^3) time.
 *
 * @param arr1 left operand; its length is taken as the matrix size n
 * @param arr2 right operand, expected to be at least n×n
 * @return a freshly allocated n×n matrix holding arr1 · arr2
 */
public static int[][] squareMatrixMultiply(int[][] arr1, int[][] arr2) {
    final int size = arr1.length;
    final int[][] product = new int[size][size];
    for (int row = 0; row < size; row++) {
        final int[] leftRow = arr1[row];
        for (int col = 0; col < size; col++) {
            int dot = 0;
            for (int t = 0; t < size; t++) {
                dot += leftRow[t] * arr2[t][col];
            }
            product[row][col] = dot;
        }
    }
    return product;
}
这种方法实现起来最简单,但时间复杂度也最高,为Θ(n^3)。
递归实现矩阵乘法
矩阵划分代码
/**
 * Splits a square matrix into its four quadrants by copying elements.
 *
 * <p>With half = arr.length / 2, the top-left, top-right, bottom-left and
 * bottom-right half×half blocks of {@code arr} are written into
 * {@code arr11}, {@code arr12}, {@code arr21} and {@code arr22}. When the size
 * is odd, the last row and column of {@code arr} are not copied anywhere.
 *
 * @param arr   source matrix
 * @param arr11 destination for the top-left block, at least half×half
 * @param arr12 destination for the top-right block, at least half×half
 * @param arr21 destination for the bottom-left block, at least half×half
 * @param arr22 destination for the bottom-right block, at least half×half
 */
public static void divideMatrix(int[][] arr, int[][] arr11, int[][] arr12, int[][] arr21, int[][] arr22) {
    final int half = arr.length / 2;
    for (int r = 0; r < half; r++) {
        final int[] top = arr[r];
        final int[] bottom = arr[r + half];
        for (int c = 0; c < half; c++) {
            arr11[r][c] = top[c];
            arr12[r][c] = top[c + half];
            arr21[r][c] = bottom[c];
            arr22[r][c] = bottom[c + half];
        }
    }
}
其实《算法导论》上说矩阵划分可以在Θ(1)时间内完成,但我还没想出怎么实现。(思路:不复制元素,而是用"原矩阵 + 子矩阵在原矩阵中的行、列偏移量"来表示子矩阵,划分时只需计算下标;而上面这种逐元素复制的做法需要Θ(n^2)时间。)
矩阵相加减代码
/**
 * Adds or subtracts two square matrices element by element.
 *
 * @param arr1 left operand; its length is taken as the matrix size n
 * @param arr2 right operand, expected to be at least n×n
 * @param flag {@code 1} selects arr1 + arr2; any other value selects arr1 - arr2
 * @return a freshly allocated n×n matrix with the element-wise result
 */
public static int[][] addOrSubtractionMatrix(int[][] arr1, int[][] arr2, int flag) {
    final int size = arr1.length;
    final boolean add = flag == 1;
    final int[][] result = new int[size][size];
    for (int r = 0; r < size; r++) {
        final int[] out = result[r];
        final int[] x = arr1[r];
        final int[] y = arr2[r];
        for (int c = 0; c < size; c++) {
            out[c] = add ? x[c] + y[c] : x[c] - y[c];
        }
    }
    return result;
}
矩阵结合代码
/**
 * Writes four quadrant blocks back into a square matrix (inverse of divideMatrix).
 *
 * <p>With half = arr.length / 2, the half×half blocks {@code arr11},
 * {@code arr12}, {@code arr21} and {@code arr22} are copied into the
 * top-left, top-right, bottom-left and bottom-right corners of {@code arr}.
 * When the size is odd, the last row and column of {@code arr} are left untouched.
 *
 * @param arr   destination matrix, modified in place
 * @param arr11 top-left block
 * @param arr12 top-right block
 * @param arr21 bottom-left block
 * @param arr22 bottom-right block
 */
public static void combineMatrix(int[][] arr, int[][] arr11, int[][] arr12, int[][] arr21, int[][] arr22) {
    final int half = arr.length / 2;
    for (int r = 0; r < half; r++) {
        final int[] top = arr[r];
        final int[] bottom = arr[r + half];
        for (int c = 0; c < half; c++) {
            top[c] = arr11[r][c];
            top[c + half] = arr12[r][c];
            bottom[c] = arr21[r][c];
            bottom[c + half] = arr22[r][c];
        }
    }
}
最终代码
/**
 * Multiplies two n×n matrices with plain divide and conquer (block multiplication).
 *
 * <p>Each operand is split into four n/2×n/2 blocks. Each result block is the sum of
 * two recursive block products, for eight products and four block additions per
 * level: T(n) = 8T(n/2) + Θ(n^2) = Θ(n^3).
 *
 * <p>divideMatrix only copies the first 2*(n/2) rows and columns. An odd-sized
 * input is therefore padded with one zero row and column before recursing, and the
 * padding is stripped from the result. This makes every n >= 0 supported.
 *
 * @param arr1 left operand; its length is taken as the matrix size n
 * @param arr2 right operand, expected to be n×n
 * @return a freshly allocated n×n matrix holding arr1 · arr2
 */
public static int[][] squareMatrixMultiply2(int[][] arr1, int[][] arr2) {
    int n = arr1.length;
    int[][] arr3 = new int[n][n];
    if (n == 0) {
        // Without this guard n/2 == 0 would recurse on 0×0 forever.
        return arr3;
    }
    if (n == 1) {
        arr3[0][0] = arr1[0][0] * arr2[0][0];
        return arr3;
    }
    if (n % 2 != 0) {
        // Zero padding leaves the top-left n×n block of the product unchanged.
        int[][] paddedA = new int[n + 1][n + 1];
        int[][] paddedB = new int[n + 1][n + 1];
        for (int i = 0; i < n; i++) {
            System.arraycopy(arr1[i], 0, paddedA[i], 0, n);
            System.arraycopy(arr2[i], 0, paddedB[i], 0, n);
        }
        int[][] paddedProduct = squareMatrixMultiply2(paddedA, paddedB);
        for (int i = 0; i < n; i++) {
            System.arraycopy(paddedProduct[i], 0, arr3[i], 0, n);
        }
        return arr3;
    }

    int m = n / 2;
    int[][] arrA11 = new int[m][m];
    int[][] arrA12 = new int[m][m];
    int[][] arrA21 = new int[m][m];
    int[][] arrA22 = new int[m][m];
    int[][] arrB11 = new int[m][m];
    int[][] arrB12 = new int[m][m];
    int[][] arrB21 = new int[m][m];
    int[][] arrB22 = new int[m][m];
    divideMatrix(arr1, arrA11, arrA12, arrA21, arrA22);
    divideMatrix(arr2, arrB11, arrB12, arrB21, arrB22);

    // C_ij = A_i1 * B_1j + A_i2 * B_2j  (flag 1 = addition)
    int[][] arrC11 = addOrSubtractionMatrix(
            squareMatrixMultiply2(arrA11, arrB11), squareMatrixMultiply2(arrA12, arrB21), 1);
    int[][] arrC12 = addOrSubtractionMatrix(
            squareMatrixMultiply2(arrA11, arrB12), squareMatrixMultiply2(arrA12, arrB22), 1);
    int[][] arrC21 = addOrSubtractionMatrix(
            squareMatrixMultiply2(arrA21, arrB11), squareMatrixMultiply2(arrA22, arrB21), 1);
    int[][] arrC22 = addOrSubtractionMatrix(
            squareMatrixMultiply2(arrA21, arrB12), squareMatrixMultiply2(arrA22, arrB22), 1);

    combineMatrix(arr3, arrC11, arrC12, arrC21, arrC22);
    return arr3;
}
递归实现的思路很明确:采用分块矩阵乘法,把大矩阵一次次划分为更小的子矩阵,再分别进行分块矩阵乘法。
不过这种方法并不比原来的方法快:
设 n*n 矩阵相乘的时间复杂度为 T(n),
则 n/2*n/2 子矩阵相乘的时间复杂度为 T(n/2);
每一层递归需要做 8 次子矩阵乘法,外加 4 次子矩阵加法(以及划分和合并),后者共需Θ(n^2)时间,
所以递推式为 T(n)=8T(n/2)+Θ(n^2),由主定理解得 T(n)=Θ(n^3),与最简单的方法相同。
Strassen’s 矩阵乘法
/**
 * Multiplies two n×n matrices with Strassen's algorithm.
 *
 * <p>Each operand is split into four n/2×n/2 blocks. Ten block sums/differences
 * S1..S10 are formed first. Exactly seven recursive products P1..P7 are then
 * computed, and the result blocks are recombined with additions only:
 * T(n) = 7T(n/2) + Θ(n^2) = Θ(n^lg7) ≈ Θ(n^2.81).
 *
 * <p>The earlier version expanded every P_i by distributivity, e.g.
 * P5 = A11·B11 + A11·B22 + A22·B11 + A22·B22. That gave correct values but 20
 * recursive products per level, which is worse than the 8-product recursion.
 *
 * <p>divideMatrix only copies the first 2*(n/2) rows and columns. An odd-sized
 * input is therefore padded with one zero row and column before recursing, and the
 * padding is stripped from the result. This makes every n >= 0 supported.
 *
 * @param arrA left operand; its length is taken as the matrix size n
 * @param arrB right operand, expected to be n×n
 * @return a freshly allocated n×n matrix holding arrA · arrB
 */
public static int[][] strassenA(int[][] arrA, int[][] arrB) {
    int n = arrA.length;
    int[][] arr3 = new int[n][n];
    if (n == 0) {
        // Without this guard n/2 == 0 would recurse on 0×0 forever.
        return arr3;
    }
    if (n == 1) {
        arr3[0][0] = arrA[0][0] * arrB[0][0];
        return arr3;
    }
    if (n % 2 != 0) {
        // Zero padding leaves the top-left n×n block of the product unchanged.
        int[][] paddedA = new int[n + 1][n + 1];
        int[][] paddedB = new int[n + 1][n + 1];
        for (int i = 0; i < n; i++) {
            System.arraycopy(arrA[i], 0, paddedA[i], 0, n);
            System.arraycopy(arrB[i], 0, paddedB[i], 0, n);
        }
        int[][] paddedProduct = strassenA(paddedA, paddedB);
        for (int i = 0; i < n; i++) {
            System.arraycopy(paddedProduct[i], 0, arr3[i], 0, n);
        }
        return arr3;
    }

    int m = n / 2;
    int[][] arrA11 = new int[m][m];
    int[][] arrA12 = new int[m][m];
    int[][] arrA21 = new int[m][m];
    int[][] arrA22 = new int[m][m];
    int[][] arrB11 = new int[m][m];
    int[][] arrB12 = new int[m][m];
    int[][] arrB21 = new int[m][m];
    int[][] arrB22 = new int[m][m];
    divideMatrix(arrA, arrA11, arrA12, arrA21, arrA22);
    divideMatrix(arrB, arrB11, arrB12, arrB21, arrB22);

    // Step 1: ten block sums/differences (flag 1 = add, flag 0 = subtract).
    int[][] s1 = addOrSubtractionMatrix(arrB12, arrB22, 0);  // B12 - B22
    int[][] s2 = addOrSubtractionMatrix(arrA11, arrA12, 1);  // A11 + A12
    int[][] s3 = addOrSubtractionMatrix(arrA21, arrA22, 1);  // A21 + A22
    int[][] s4 = addOrSubtractionMatrix(arrB21, arrB11, 0);  // B21 - B11
    int[][] s5 = addOrSubtractionMatrix(arrA11, arrA22, 1);  // A11 + A22
    int[][] s6 = addOrSubtractionMatrix(arrB11, arrB22, 1);  // B11 + B22
    int[][] s7 = addOrSubtractionMatrix(arrA12, arrA22, 0);  // A12 - A22
    int[][] s8 = addOrSubtractionMatrix(arrB21, arrB22, 1);  // B21 + B22
    int[][] s9 = addOrSubtractionMatrix(arrA11, arrA21, 0);  // A11 - A21
    int[][] s10 = addOrSubtractionMatrix(arrB11, arrB12, 1); // B11 + B12

    // Step 2: exactly seven recursive products.
    int[][] p1 = strassenA(arrA11, s1);
    int[][] p2 = strassenA(s2, arrB22);
    int[][] p3 = strassenA(s3, arrB11);
    int[][] p4 = strassenA(arrA22, s4);
    int[][] p5 = strassenA(s5, s6);
    int[][] p6 = strassenA(s7, s8);
    int[][] p7 = strassenA(s9, s10);

    // Step 3: recombine with additions only.
    // C11 = P5 + P4 - P2 + P6
    int[][] arrC11 = addOrSubtractionMatrix(
            addOrSubtractionMatrix(addOrSubtractionMatrix(p5, p4, 1), p6, 1), p2, 0);
    // C12 = P1 + P2
    int[][] arrC12 = addOrSubtractionMatrix(p1, p2, 1);
    // C21 = P3 + P4
    int[][] arrC21 = addOrSubtractionMatrix(p3, p4, 1);
    // C22 = P5 + P1 - P3 - P7
    int[][] arrC22 = addOrSubtractionMatrix(
            addOrSubtractionMatrix(p5, p1, 1), addOrSubtractionMatrix(p3, p7, 1), 0);

    combineMatrix(arr3, arrC11, arrC12, arrC21, arrC22);
    return arr3;
}
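Strassen 算法把每层递归中的 8 次子矩阵乘法减少为 7 次,省去的那一次乘法由常数次(共 18 次)子矩阵加减法替代,而加减法只需Θ(n^2)时间。
最终的递推式为 T(n)=7T(n/2)+Θ(n^2),由主定理解得 T(n)=Θ(n^{lg7})≈Θ(n^{2.81})。由此可见,在本文介绍的三种方法中,它的渐近复杂度最低。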