矩阵乘法-java

矩阵乘法公式
这里写图片描述

//矩阵乘法运算
public class MatrixMul {
    public static void main(String[] args){
        int A[][]={{1,2,3},{4,5,6}};
        int B[][]={{1,4},{2,5},{3,6}};
        //计算A*B则输出的矩阵就是C[A的行数][B的列数]
        int C[][]=new int[A.length][B[0].length];
       for(int i=0;i<A.length;i++){
           for(int j=0;j<B[0].length;j++){
               int t=0;
               //计算矩阵行乘类的和
               for(int k=0;k<A[0].length;k++){ //这里k可以小于A的列数或者B的行数
                   t=A[i][k]*B[k][j]+t;
               }
               C[i][j]=t;
           }
       }
       for(int i=0;i<C.length;i++){
           for(int j=0;j<C[0].length;j++){
               System.out.print(C[i][j]+" ");
           }
           System.out.println();
       }
    }
}

这里写图片描述
Strassen算法实现
虽然代码很长,但很好理解
参考文章

http://m.mamicode.com/info-detail-673908.html 写的很详细!
http://blog.csdn.net/tanlingyun/article/details/1591414 作者的代码有点问题!

算法的关键
这里写图片描述

class Matrix {

    public int[][] m = new int[32][32];
}
//Strassen矩阵乘法

public class Strassen {

    //判断阶数是否是2的n次方
    public int judge(int p) {
        int mark = 0;
        if (p < 1) {  //矩阵的阶数必须大于等于1
            return mark;
        }
        while (p % 2 == 0) {
            p /= 2;
        }
        if (p == 1) {
            mark = 1;
        }
        return mark;
    }

    //a b 两个矩阵加法 
    public Matrix addMatrix(Matrix a, Matrix b, int n) /*矩阵加法方法*/ {
        Matrix c = new Matrix();
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                c.m[i][j] = a.m[i][j] + b.m[i][j];
            }
        }
        return c;
    }
    //a b 两个矩阵减法,就改了个加号和名字

    public Matrix subMatrix(Matrix a, Matrix b, int n) /*矩阵加法方法*/ {
        Matrix c = new Matrix();
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                c.m[i][j] = a.m[i][j] - b.m[i][j];
            }
        }
        return c;
    }

    //2阶矩阵乘法和上一个算法差不多
    public Matrix MatrixMultiply(Matrix a, Matrix b) {
        Matrix c = new Matrix();
        for (int i = 1; i <= 2; i++) {
            for (int j = 1; j <= 2; j++) {
                int t = 0;
                for (int k = 1; k <= 2; k++) {
                    t = a.m[i][k] * b.m[k][j] + t;
                }
                c.m[i][j] = t;
            }
        }
        return c;
    }

    //分解矩阵x为4块
    public void decompose(Matrix x, Matrix x11, Matrix x12, Matrix x21, Matrix x22, int n)/*分解矩阵方法*/ {
        int i, j;
        for (i = 1; i <= n; i++) {
            for (j = 1; j <= n; j++) {
                x11.m[i][j] = x.m[i][j];
                x12.m[i][j] = x.m[i][j + n];
                x21.m[i][j] = x.m[i + n][j];
                x22.m[i][j] = x.m[i + n][j + n];
            }
        }
    }

    //合并算法
    public Matrix merge(Matrix c11, Matrix c12, Matrix c21, Matrix c22, int n)/*合并矩阵方法*/ {
        int i, j;
        Matrix c = new Matrix();
        for (i = 1; i <= n; i++) {
            for (j = 1; j <= n; j++) {
                c.m[i][j] = c11.m[i][j];
                c.m[i][j + n] = c12.m[i][j];
                c.m[i + n][j] = c21.m[i][j];
                c.m[i + n][j + n] = c22.m[i][j];
            }
        }
        return c;
    }

    public Matrix strassen(Matrix a, Matrix b, int n) {
        int halfsize = n / 2;
        Matrix a11, a12, a21, a22, b11, b12, b21, b22, c11, c12, c21, c22, c;
        //分解a矩阵
        a11 = new Matrix();
        a12 = new Matrix();
        a21 = new Matrix();
        a22 = new Matrix();
        //分解b矩阵
        b11 = new Matrix();
        b12 = new Matrix();
        b21 = new Matrix();
        b22 = new Matrix();
        //分解c矩阵
        c11 = new Matrix();
        c12 = new Matrix();
        c21 = new Matrix();
        c22 = new Matrix();
        //输出结果c
        c = new Matrix();
        //7个新矩阵
        Matrix m1, m2, m3, m4, m5, m6, m7;

        if (n == 2) {
            c = MatrixMultiply(a, b);
            return c;
        } else {
            //分解矩阵a b c
            decompose(a, a11, a12, a21, a22, halfsize);
            decompose(b, b11, b12, b21, b22, halfsize);
            decompose(c, c11, c12, c21, c22, halfsize);
            //计算超级烦的m1--m7,注意眼睛别看花了
            m1 = strassen(addMatrix(a11, a22, halfsize), addMatrix(b11, b22, halfsize), halfsize);
            m2 = strassen(addMatrix(a11, a22, halfsize), b11, halfsize);
            m3 = strassen(a11, subMatrix(b12, b22, halfsize), halfsize);
            m4 = strassen(a22, subMatrix(b21, b11, halfsize), halfsize);
            m5 = strassen(addMatrix(a11, a12, halfsize), b22, halfsize);
            m6 = strassen(subMatrix(a21, a11, halfsize), addMatrix(b11, b12, halfsize), halfsize);
            m7 = strassen(subMatrix(a12, a22, halfsize), addMatrix(b21, b22, halfsize), halfsize);
            //计算c11-c22
            c11 = addMatrix(subMatrix(addMatrix(m1, m4, halfsize), m5, halfsize), m7, halfsize);
            c12 = addMatrix(m3, m5, halfsize);
            c21 = addMatrix(m2, m4, halfsize);
            c22 = addMatrix(subMatrix(addMatrix(m1, m3, halfsize), m2, halfsize), m6, halfsize);
            //合并c11-c22到c中
            c = merge(c11, c12, c21, c22, halfsize);
            return c;
        }

    }

    public static void main(String args[]) {
        Strassen stra = new Strassen();
        Scanner scan = new Scanner(System.in);
        System.out.println("输入矩阵的阶数:");
        int p = scan.nextInt();
        if (stra.judge(p) == 1) {
            Matrix a = new Matrix();
            Matrix b = new Matrix();
            Matrix c = new Matrix();
            System.out.println("输入矩阵A:");
            for (int i = 1; i <= p; i++) {
                for (int j = 1; j <= p; j++) {
                    a.m[i][j] = scan.nextInt();
                }
            }
            System.out.println("输入矩阵B:");
            for (int i = 1; i <= p; i++) {
                for (int j = 1; j <= p; j++) {
                    b.m[i][j] = scan.nextInt();
                }
            }
            System.out.println("矩阵C:");
            //当矩阵阶数为1时,单独处理
            if (p == 1) {
                c.m[1][1] = a.m[1][1] * b.m[1][1];
            } else {
                c = stra.strassen(a, b, p);
            }
            for (int i = 1; i <= p; i++) {
                for (int j = 1; j <= p; j++) {
                    System.out.print(c.m[i][j] + " ");
                }
                System.out.println();
            }
        } else {
            System.out.println(p + "不是2的n次方!");
        }
    }

}

这里写图片描述

如果有错,请联系我,谢谢。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值