学过线性代数的人都知道两个矩阵的乘法怎么求。对于两个N阶方阵,按定义直接计算的时间复杂度为O(N^3)。下面直接给出标准算法的实现。
代码:
/**
 * Classical (definition-based) integer matrix multiplication.
 */
public class MartixMultiply {

    /**
     * Computes the matrix product {@code a x b} with the textbook triple loop.
     *
     * <p>{@code a} may be any {@code m x p} matrix and {@code b} any
     * {@code p x n} matrix; the result is {@code m x n}. Arithmetic is plain
     * {@code int} arithmetic, so overflow wraps silently.
     *
     * @param a left operand; every row must have {@code b.length} entries
     * @param b right operand; every row must have the same length
     * @return a newly allocated matrix holding the product
     * @throws IllegalArgumentException if the operand shapes are incompatible
     *         or either matrix is ragged
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = (inner == 0) ? 0 : b[0].length;

        for (int[] row : a) {
            if (row.length != inner) {
                throw new IllegalArgumentException(
                        "row of a has " + row.length + " columns but b has " + inner + " rows");
            }
        }
        for (int[] row : b) {
            if (row.length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: expected " + cols + " columns but found " + row.length);
            }
        }

        // Java zero-fills new int arrays, so no explicit initialization is needed.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Demo: multiplies two 2x2 matrices and prints the four entries of the
     * product in row-major order.
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 3, 4 }, { 7, 2 } };
        int[][] c = multiply(a, b);
        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}
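Strassen提出的算法打破了O(N^3)的屏障。它采用分治法,把每个N阶矩阵分为4个N/2阶的子块:
$$A=\begin{pmatrix}A_{00}&A_{01}\\A_{10}&A_{11}\end{pmatrix},\quad B=\begin{pmatrix}B_{00}&B_{01}\\B_{10}&B_{11}\end{pmatrix},\quad C=AB=\begin{pmatrix}C_{00}&C_{01}\\C_{10}&C_{11}\end{pmatrix}$$
其中
$$\begin{aligned}
M_1&=(A_{00}+A_{11})(B_{00}+B_{11}), & M_2&=(A_{10}+A_{11})B_{00},\\
M_3&=A_{00}(B_{01}-B_{11}), & M_4&=A_{11}(B_{10}-B_{00}),\\
M_5&=(A_{00}+A_{01})B_{11}, & M_6&=(A_{10}-A_{00})(B_{00}+B_{01}),\\
M_7&=(A_{01}-A_{11})(B_{10}+B_{11}),\\
C_{00}&=M_1+M_4-M_5+M_7, & C_{01}&=M_3+M_5,\\
C_{10}&=M_2+M_4, & C_{11}&=M_1-M_2+M_3+M_6.
\end{aligned}$$
只需要7次子矩阵乘法(而不是8次)以及若干次O(N^2)的矩阵加减法。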
由此可以得到递推关系T(N)=7T(N/2)+O(N^2),依据主定理解得T(N)=O(N^{log2 7})≈O(N^2.81)。
这里不给出证明。显然,该算法用到了分治法的思想。
代码:
/**
 * Integer matrix multiplication using Strassen's divide-and-conquer algorithm,
 * with the classical triple-loop algorithm as the recursion base case.
 */
public class MartixMultiply {

    /**
     * Multiplies two matrices with Strassen's algorithm.
     *
     * <p>The Strassen step splits each operand into four equal quadrants, so
     * it is applied only while both operands are square, of the same even
     * order, and larger than 2x2. In every other case (order 0, 1 or 2, odd
     * order, non-square input) the call is delegated to
     * {@link #multiply(int[][], int[][])}, which keeps the recursion finite
     * and the result correct for any compatible shapes.
     *
     * @param a left operand
     * @param b right operand
     * @return a newly allocated matrix holding {@code a x b}
     * @throws IllegalArgumentException if the operand shapes are incompatible
     */
    public static int[][] StrassenMultiply(int[][] a, int[][] b) {
        int n = a.length;
        // End the recursion (or skip Strassen entirely) whenever the matrix
        // cannot be cut into four equal square quadrants.
        if (n <= 2 || n % 2 != 0 || !isSquare(a, n) || !isSquare(b, n)) {
            return multiply(a, b);
        }
        int half = n / 2;

        // The four quadrants of a.
        int[][] a00 = subMatrix(a, 0, 0, half);
        int[][] a01 = subMatrix(a, 0, half, half);
        int[][] a10 = subMatrix(a, half, 0, half);
        int[][] a11 = subMatrix(a, half, half, half);
        // The four quadrants of b.
        int[][] b00 = subMatrix(b, 0, 0, half);
        int[][] b01 = subMatrix(b, 0, half, half);
        int[][] b10 = subMatrix(b, half, 0, half);
        int[][] b11 = subMatrix(b, half, half, half);

        // Strassen's seven half-size products.
        int[][] m1 = StrassenMultiply(addArrays(a00, a11), addArrays(b00, b11));
        int[][] m2 = StrassenMultiply(addArrays(a10, a11), b00);
        int[][] m3 = StrassenMultiply(a00, subArrays(b01, b11));
        int[][] m4 = StrassenMultiply(a11, subArrays(b10, b00));
        int[][] m5 = StrassenMultiply(addArrays(a00, a01), b11);
        int[][] m6 = StrassenMultiply(subArrays(a10, a00), addArrays(b00, b01));
        int[][] m7 = StrassenMultiply(subArrays(a01, a11), addArrays(b10, b11));

        int[][] c00 = addArrays(m7, subArrays(addArrays(m1, m4), m5)); // m1+m4-m5+m7
        int[][] c01 = addArrays(m3, m5);                               // m3+m5
        int[][] c10 = addArrays(m2, m4);                               // m2+m4
        int[][] c11 = addArrays(m6, subArrays(addArrays(m1, m3), m2)); // m1+m3-m2+m6

        // Assemble the four result quadrants into one matrix.
        int[][] result = new int[n][n];
        merge(result, c00, 0, 0);
        merge(result, c01, 0, half);
        merge(result, c10, half, 0);
        merge(result, c11, half, half);
        return result;
    }

    /** Returns true when {@code m} has exactly {@code n} rows of {@code n} entries each. */
    private static boolean isSquare(int[][] m, int n) {
        if (m.length != n) {
            return false;
        }
        for (int[] row : m) {
            if (row.length != n) {
                return false;
            }
        }
        return true;
    }

    /** Copies the {@code size x size} block of {@code src} whose top-left corner is (rowOffset, colOffset). */
    private static int[][] subMatrix(int[][] src, int rowOffset, int colOffset, int size) {
        int[][] block = new int[size][size];
        for (int i = 0; i < size; i++) {
            System.arraycopy(src[rowOffset + i], colOffset, block[i], 0, size);
        }
        return block;
    }

    /** Element-wise sum of two square matrices of the same order. */
    private static int[][] addArrays(int[][] a, int[][] b) {
        int n = a.length;
        int[][] result = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    /** Element-wise difference {@code a - b} of two square matrices of the same order. */
    private static int[][] subArrays(int[][] a, int[][] b) {
        int n = a.length;
        int[][] result = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = a[i][j] - b[i][j];
            }
        }
        return result;
    }

    /** Copies {@code block} into {@code dest} with its top-left corner at (rowOffset, colOffset). */
    private static void merge(int[][] dest, int[][] block, int rowOffset, int colOffset) {
        for (int i = 0; i < block.length; i++) {
            System.arraycopy(block[i], 0, dest[rowOffset + i], colOffset, block[i].length);
        }
    }

    /**
     * Classical triple-loop matrix product.
     *
     * <p>{@code a} may be any {@code m x p} matrix and {@code b} any
     * {@code p x n} matrix; the result is {@code m x n}. Arithmetic is plain
     * {@code int} arithmetic, so overflow wraps silently.
     *
     * @param a left operand; every row must have {@code b.length} entries
     * @param b right operand; every row must have the same length
     * @return a newly allocated matrix holding the product
     * @throws IllegalArgumentException if the operand shapes are incompatible
     *         or either matrix is ragged
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = (inner == 0) ? 0 : b[0].length;

        for (int[] row : a) {
            if (row.length != inner) {
                throw new IllegalArgumentException(
                        "row of a has " + row.length + " columns but b has " + inner + " rows");
            }
        }
        for (int[] row : b) {
            if (row.length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: expected " + cols + " columns but found " + row.length);
            }
        }

        // Java zero-fills new int arrays, so no explicit initialization is needed.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Demo: multiplies two 4x4 matrices with Strassen's algorithm and prints
     * the top-left 2x2 corner of the product in row-major order.
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2, 6, 7 }, { 3, 4, 5, 4 }, { 5, 8, 3, 8 },
                { -6, 4, 3, 9 } };
        int[][] b = { { 3, 4, 9, 0 }, { 7, 2, -5, -6 }, { 0, 7, -4, 6 },
                { -6, 3, -5, 4 } };
        int[][] c = StrassenMultiply(a, b);
        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}