矩阵乘法公式
//矩阵乘法运算
public class MatrixMul {
public static void main(String[] args){
int A[][]={{1,2,3},{4,5,6}};
int B[][]={{1,4},{2,5},{3,6}};
//计算A*B则输出的矩阵就是C[A的行数][B的列数]
int C[][]=new int[A.length][B[0].length];
for(int i=0;i<A.length;i++){
for(int j=0;j<B[0].length;j++){
int t=0;
//计算矩阵行乘类的和
for(int k=0;k<A[0].length;k++){ //这里k可以小于A的列数或者B的行数
t=A[i][k]*B[k][j]+t;
}
C[i][j]=t;
}
}
for(int i=0;i<C.length;i++){
for(int j=0;j<C[0].length;j++){
System.out.print(C[i][j]+" ");
}
System.out.println();
}
}
}
Strassen算法实现
虽然代码很长,但很好理解
参考文章
http://m.mamicode.com/info-detail-673908.html 写的很详细!
http://blog.csdn.net/tanlingyun/article/details/1591414 作者的代码有点问题!
算法的关键
class Matrix {
public int[][] m = new int[32][32];
}
//Strassen矩阵乘法
public class Strassen {
//判断阶数是否是2的n次方
public int judge(int p) {
int mark = 0;
if (p < 1) { //矩阵的阶数必须大于等于1
return mark;
}
while (p % 2 == 0) {
p /= 2;
}
if (p == 1) {
mark = 1;
}
return mark;
}
//a b 两个矩阵加法
public Matrix addMatrix(Matrix a, Matrix b, int n) /*矩阵加法方法*/ {
Matrix c = new Matrix();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
c.m[i][j] = a.m[i][j] + b.m[i][j];
}
}
return c;
}
//a b 两个矩阵减法,就改了个加号和名字
public Matrix subMatrix(Matrix a, Matrix b, int n) /*矩阵加法方法*/ {
Matrix c = new Matrix();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
c.m[i][j] = a.m[i][j] - b.m[i][j];
}
}
return c;
}
//2阶矩阵乘法和上一个算法差不多
public Matrix MatrixMultiply(Matrix a, Matrix b) {
Matrix c = new Matrix();
for (int i = 1; i <= 2; i++) {
for (int j = 1; j <= 2; j++) {
int t = 0;
for (int k = 1; k <= 2; k++) {
t = a.m[i][k] * b.m[k][j] + t;
}
c.m[i][j] = t;
}
}
return c;
}
//分解矩阵x为4块
public void decompose(Matrix x, Matrix x11, Matrix x12, Matrix x21, Matrix x22, int n)/*分解矩阵方法*/ {
int i, j;
for (i = 1; i <= n; i++) {
for (j = 1; j <= n; j++) {
x11.m[i][j] = x.m[i][j];
x12.m[i][j] = x.m[i][j + n];
x21.m[i][j] = x.m[i + n][j];
x22.m[i][j] = x.m[i + n][j + n];
}
}
}
//合并算法
public Matrix merge(Matrix c11, Matrix c12, Matrix c21, Matrix c22, int n)/*合并矩阵方法*/ {
int i, j;
Matrix c = new Matrix();
for (i = 1; i <= n; i++) {
for (j = 1; j <= n; j++) {
c.m[i][j] = c11.m[i][j];
c.m[i][j + n] = c12.m[i][j];
c.m[i + n][j] = c21.m[i][j];
c.m[i + n][j + n] = c22.m[i][j];
}
}
return c;
}
public Matrix strassen(Matrix a, Matrix b, int n) {
int halfsize = n / 2;
Matrix a11, a12, a21, a22, b11, b12, b21, b22, c11, c12, c21, c22, c;
//分解a矩阵
a11 = new Matrix();
a12 = new Matrix();
a21 = new Matrix();
a22 = new Matrix();
//分解b矩阵
b11 = new Matrix();
b12 = new Matrix();
b21 = new Matrix();
b22 = new Matrix();
//分解c矩阵
c11 = new Matrix();
c12 = new Matrix();
c21 = new Matrix();
c22 = new Matrix();
//输出结果c
c = new Matrix();
//7个新矩阵
Matrix m1, m2, m3, m4, m5, m6, m7;
if (n == 2) {
c = MatrixMultiply(a, b);
return c;
} else {
//分解矩阵a b c
decompose(a, a11, a12, a21, a22, halfsize);
decompose(b, b11, b12, b21, b22, halfsize);
decompose(c, c11, c12, c21, c22, halfsize);
//计算超级烦的m1--m7,注意眼睛别看花了
m1 = strassen(addMatrix(a11, a22, halfsize), addMatrix(b11, b22, halfsize), halfsize);
m2 = strassen(addMatrix(a11, a22, halfsize), b11, halfsize);
m3 = strassen(a11, subMatrix(b12, b22, halfsize), halfsize);
m4 = strassen(a22, subMatrix(b21, b11, halfsize), halfsize);
m5 = strassen(addMatrix(a11, a12, halfsize), b22, halfsize);
m6 = strassen(subMatrix(a21, a11, halfsize), addMatrix(b11, b12, halfsize), halfsize);
m7 = strassen(subMatrix(a12, a22, halfsize), addMatrix(b21, b22, halfsize), halfsize);
//计算c11-c22
c11 = addMatrix(subMatrix(addMatrix(m1, m4, halfsize), m5, halfsize), m7, halfsize);
c12 = addMatrix(m3, m5, halfsize);
c21 = addMatrix(m2, m4, halfsize);
c22 = addMatrix(subMatrix(addMatrix(m1, m3, halfsize), m2, halfsize), m6, halfsize);
//合并c11-c22到c中
c = merge(c11, c12, c21, c22, halfsize);
return c;
}
}
public static void main(String args[]) {
Strassen stra = new Strassen();
Scanner scan = new Scanner(System.in);
System.out.println("输入矩阵的阶数:");
int p = scan.nextInt();
if (stra.judge(p) == 1) {
Matrix a = new Matrix();
Matrix b = new Matrix();
Matrix c = new Matrix();
System.out.println("输入矩阵A:");
for (int i = 1; i <= p; i++) {
for (int j = 1; j <= p; j++) {
a.m[i][j] = scan.nextInt();
}
}
System.out.println("输入矩阵B:");
for (int i = 1; i <= p; i++) {
for (int j = 1; j <= p; j++) {
b.m[i][j] = scan.nextInt();
}
}
System.out.println("矩阵C:");
//当矩阵阶数为1时,单独处理
if (p == 1) {
c.m[1][1] = a.m[1][1] * b.m[1][1];
} else {
c = stra.strassen(a, b, p);
}
for (int i = 1; i <= p; i++) {
for (int j = 1; j <= p; j++) {
System.out.print(c.m[i][j] + " ");
}
System.out.println();
}
} else {
System.out.println(p + "不是2的n次方!");
}
}
}
如果有错,请联系我,谢谢。