这里主要给出实现方法,至于算法的介绍,可以参考 @金戈大王 的介绍。
下面的算法是有bug的:虽然已经对一些非n*n的矩阵做了处理,但并不完善,计算非n*n的矩阵时会出现数组越界异常。所以两个相乘的矩阵必须是n*n的。
不啰嗦,下面直接贴出Java的实现方法:
//调用入口:
/**
 * Multiplies two integer matrices, using Strassen's divide-and-conquer algorithm when it applies.
 *
 * <p>The recursion is used only when {@code a} and {@code b} are both square, have the same
 * size {@code n}, and {@code n} is a power of two greater than 2. Every other input
 * (non-square or mismatched shapes, sizes that are not a power of two, and the 1x1 and 2x2
 * base cases) is delegated to {@code MatrixMulti}, which returns {@code null} when the column
 * count of {@code a} differs from the row count of {@code b}.
 *
 * <p>Each recursion level:
 * <ol>
 *   <li>splits both operands into four n/2 x n/2 quadrants;</li>
 *   <li>forms the ten sums/differences S1..S10;</li>
 *   <li>computes seven recursive products P1..P7;</li>
 *   <li>recombines them into the four quadrants of the result.</li>
 * </ol>
 *
 * @param a left operand; must have at least one row
 * @param b right operand; must have at least one row
 * @return the product {@code a * b}, or {@code null} if the shapes are incompatible
 */
public static int[][] StrassenMulti(int a [][], int b[][])
{
    int n = a.length;
    boolean equalSquares = n == a[0].length && b.length == b[0].length && n == b.length;
    // Only equally sized square operands whose size halves evenly at every level qualify.
    if (!equalSquares || (n & (n - 1)) != 0)
    {
        return MatrixMulti(a, b);
    }
    // Base case. A 1x1 input must stop here too: splitting it would produce 0x0 quadrants,
    // which cannot be indexed by the split helper.
    if (n <= 2)
    {
        return MatrixMulti(a, b);
    }

    int half = n / 2;
    int[][] A11 = new int[half][half];
    int[][] A12 = new int[half][half];
    int[][] A21 = new int[half][half];
    int[][] A22 = new int[half][half];
    nnMatrixSplitTo4Block(a, A11, A12, A21, A22);
    int[][] B11 = new int[half][half];
    int[][] B12 = new int[half][half];
    int[][] B21 = new int[half][half];
    int[][] B22 = new int[half][half];
    nnMatrixSplitTo4Block(b, B11, B12, B21, B22);

    // The ten auxiliary sums/differences (MatrixNeg performs subtraction).
    int[][] S1 = MatrixNeg(B12, B22);
    int[][] S2 = MatrixPlus(A11, A12);
    int[][] S3 = MatrixPlus(A21, A22);
    int[][] S4 = MatrixNeg(B21, B11);
    int[][] S5 = MatrixPlus(A11, A22);
    int[][] S6 = MatrixPlus(B11, B22);
    int[][] S7 = MatrixNeg(A12, A22);
    int[][] S8 = MatrixPlus(B21, B22);
    int[][] S9 = MatrixNeg(A11, A21);
    int[][] S10 = MatrixPlus(B11, B12);

    // Seven recursive half-size products instead of the naive eight.
    int[][] P1 = StrassenMulti(A11, S1);
    int[][] P2 = StrassenMulti(S2, B22);
    int[][] P3 = StrassenMulti(S3, B11);
    int[][] P4 = StrassenMulti(A22, S4);
    int[][] P5 = StrassenMulti(S5, S6);
    int[][] P6 = StrassenMulti(S7, S8);
    int[][] P7 = StrassenMulti(S9, S10);

    // Recombine the products into the result quadrants.
    int[][] C11 = MatrixPlus(MatrixNeg(MatrixPlus(P5, P4), P2), P6);
    int[][] C12 = MatrixPlus(P1, P2);
    int[][] C21 = MatrixPlus(P3, P4);
    int[][] C22 = MatrixNeg(MatrixNeg(MatrixPlus(P5, P1), P3), P7);
    return MatrixBlockPlus(C11, C12, C21, C22);
}
//将矩阵分为四个子矩阵
/**
 * Splits {@code src} into four quadrant blocks, copying its values into the caller-supplied
 * destination arrays.
 *
 * <p>Quadrant sizes are taken from the destination arrays:
 * <ul>
 *   <li>{@code A11} and {@code A12} receive the first {@code A11.length} rows of {@code src},
 *       with {@code A12} starting at column {@code A11[0].length};</li>
 *   <li>{@code A21} and {@code A22} receive the rows that follow, with {@code A22} starting at
 *       column {@code A21[0].length}.</li>
 * </ul>
 * Each destination is filled to its own dimensions, so the quadrants need not be equal in size.
 *
 * @param src the matrix to split; must be large enough to cover all four blocks
 * @param A11 receives the top-left block
 * @param A12 receives the top-right block
 * @param A21 receives the bottom-left block
 * @param A22 receives the bottom-right block
 * @throws IndexOutOfBoundsException if {@code src} is too small for the requested blocks
 */
public static void nnMatrixSplitTo4Block(int [][]src, int [][]A11, int [][]A12, int [][]A21, int [][]A22)
{
    int topRows = A11.length;
    // Column at which the right-hand block starts, for the upper and the lower half.
    int topSplit = A11.length > 0 ? A11[0].length : 0;
    int bottomSplit = A21.length > 0 ? A21[0].length : 0;

    for (int i = 0; i < A11.length; i++)
    {
        System.arraycopy(src[i], 0, A11[i], 0, A11[i].length);
    }
    for (int i = 0; i < A12.length; i++)
    {
        System.arraycopy(src[i], topSplit, A12[i], 0, A12[i].length);
    }
    for (int i = 0; i < A21.length; i++)
    {
        System.arraycopy(src[topRows + i], 0, A21[i], 0, A21[i].length);
    }
    for (int i = 0; i < A22.length; i++)
    {
        System.arraycopy(src[topRows + i], bottomSplit, A22[i], 0, A22[i].length);
    }
}
//将四个矩阵合并为一个矩阵
/**
 * Assembles four quadrant blocks into one matrix laid out as
 * <pre>
 *   [ A11 A12 ]
 *   [ A21 A22 ]
 * </pre>
 *
 * <p>Placement:
 * <ul>
 *   <li>In the upper rows, {@code A12} starts at column {@code A11[0].length}.</li>
 *   <li>In the lower rows, {@code A22} starts at column {@code A21[0].length}.</li>
 *   <li>{@code A21} starts below {@code A11}, and {@code A22} starts below {@code A12}.</li>
 * </ul>
 *
 * @param A11 top-left block (must have at least one row)
 * @param A12 top-right block (must have at least one row)
 * @param A21 bottom-left block (must have at least one row)
 * @param A22 bottom-right block (must have at least one row)
 * @return a new matrix containing the four blocks, or {@code null} if the upper and lower
 *         halves differ in total width or the left and right columns differ in total height
 */
public static int[][] MatrixBlockPlus(int [][]A11, int [][]A12, int [][]A21, int [][]A22)
{
    if (A11[0].length + A12[0].length != A21[0].length + A22[0].length
            || A11.length + A21.length != A12.length + A22.length)
    {
        return null;
    }
    int topSplit = A11[0].length;    // width of the left block in the upper rows
    int bottomSplit = A21[0].length; // width of the left block in the lower rows
    int[][] result = new int[A11.length + A21.length][topSplit + A12[0].length];

    for (int i = 0; i < A11.length; i++)
    {
        System.arraycopy(A11[i], 0, result[i], 0, topSplit);
    }
    for (int i = 0; i < A12.length; i++)
    {
        System.arraycopy(A12[i], 0, result[i], topSplit, A12[0].length);
    }
    for (int i = 0; i < A21.length; i++)
    {
        System.arraycopy(A21[i], 0, result[A11.length + i], 0, bottomSplit);
    }
    for (int i = 0; i < A22.length; i++)
    {
        System.arraycopy(A22[i], 0, result[A12.length + i], bottomSplit, A22[0].length);
    }
    return result;
}
//矩阵减法
/**
 * Subtracts matrix {@code b} from matrix {@code a} element-wise. Despite the name, this is
 * subtraction ({@code a - b}), not negation.
 *
 * <p>The difference is computed directly, without an intermediate negated copy of
 * {@code b}. Under Java's wrapping int arithmetic, {@code a - b} equals
 * {@code a + (-1 * b)} for every value, including {@code Integer.MIN_VALUE}.
 *
 * @param a minuend
 * @param b subtrahend
 * @return a new matrix {@code a - b}, or {@code null} if the dimensions differ
 */
public static int[][] MatrixNeg(int a[][], int b[][])
{
    if (a.length != b.length)
    {
        return null;
    }
    if (a.length == 0)
    {
        return new int[0][0];
    }
    int rows = a.length;
    int cols = a[0].length;
    if (cols != b[0].length)
    {
        return null;
    }
    int[][] result = new int[rows][cols];
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < cols; j++)
        {
            result[i][j] = a[i][j] - b[i][j];
        }
    }
    return result;
}
//矩阵加法
/**
 * Adds two matrices element-wise.
 *
 * <p>Columns are iterated up to the column count ({@code a[0].length}), so non-square
 * matrices are summed completely.
 *
 * @param a first addend
 * @param b second addend
 * @return a new matrix {@code a + b}, or {@code null} if the dimensions differ
 */
public static int[][] MatrixPlus(int a[][], int b[][])
{
    if (a.length != b.length)
    {
        return null;
    }
    if (a.length == 0)
    {
        return new int[0][0];
    }
    int rows = a.length;
    int cols = a[0].length;
    if (cols != b[0].length)
    {
        return null;
    }
    int[][] result = new int[rows][cols];
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < cols; j++)
        {
            result[i][j] = a[i][j] + b[i][j];
        }
    }
    return result;
}
//矩阵乘法
/**
 * Multiplies two matrices using the classic triple-loop definition.
 *
 * <p>The loops run in row / inner-index / column order. Each {@code a[r][k]} is read once
 * and scaled across a whole row of {@code b}. Integer sums are order-independent, so the
 * result matches the textbook row-by-column dot products.
 *
 * @param a left operand with shape rows x inner
 * @param b right operand with shape inner x cols
 * @return a new rows x cols product matrix, or {@code null} when {@code a}'s column count
 *         differs from {@code b}'s row count
 */
public static int[][] MatrixMulti(int a[][], int b[][])
{
    final int productCols = b[0].length;
    final int shared = a[0].length;
    if (shared != b.length)
    {
        return null;
    }
    final int productRows = a.length;
    final int[][] product = new int[productRows][productCols];
    for (int r = 0; r < productRows; r++)
    {
        final int[] leftRow = a[r];
        final int[] outRow = product[r];
        for (int k = 0; k < shared; k++)
        {
            final int factor = leftRow[k];
            final int[] rightRow = b[k];
            for (int c = 0; c < productCols; c++)
            {
                outRow[c] += factor * rightRow[c];
            }
        }
    }
    return product;
}