Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘

通过汉诺塔问题理解递归的精髓中我讲解了怎么把一个复杂的问题一步步recursively划分了成简单显而易见的小问题。其实这个解决问题的思路就是算法中常用的divide and conquer, 这篇日志通过解决矩阵的乘法,来了解另外一个基本divide and conque思想的strassen算法。

矩阵A乘以B等于X, 则Xij = 
注意左乘右乘的区别,AB 与BA是不同的。
如果r = 1, 直接就是两个数的相乘。
如果r = 2, 例如
X = 
[ 1, 2; 
  3, 4];
Y = 
[ 2, 3;
 4, 5];
R = XY的计算十分简单,但是如果r很大,耗时是O(r^3)。为了简化,可以把X, Y各自划分成2X2的矩阵,每一个元素其实是有n/2行的矩阵
(注:这里仅讲解行数等于列数的情况。)

X = 
[A, B;
C, D];

Y = 
[E, F;
G, H]

所以XY =[
AE+BG, AF+BH;
CE+DG, CF+DH]

Strassen引入seven magic product 分别是P1, P2, P3 ,P4, P5, P6, P7
P1 = A(F-H)
P2 = (A+B)H
P3 = (C+D)E
P4 = D(G-E)
P5 = (A+D)(E+H)
P6 = (B-D)(G+H)
P7 = (A-C)(E+F)

这样XY = 
[P5+P4-P2+P6, P1+P2;
P3+P4, P1+P5-P3-P7]

然后通过递归的策略计算矩阵的相乘,递归的出口是n = 1.

关键点就是这些,附上代码吧。

//multiply matrix multiplication
import java.util.Scanner;
public class Strassen{
    public Strassen(){}


    /** split a parent matrix into child matrics8*/
    public static void split(int[][] P, int[][] C, int iB, int jB){
        for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)
            for(int j1=0, j2=jB; j1<C.length; j1++, j2++)
                C[i1][j1] = P[i2][j2];
    }


    /**join child matric into parent matrix*/
    public static void join(int[][] C, int[][] P, int iB, int jB){
        for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)
            for(int j1=0, j2=jB; j1<C.length; j1++, j2++)
                P[i2][j2]=C[i1][j1]; 
    }


    /**add two matrics into one*/
    public static int[][] add(int[][] A, int[][] B){
        //A and B has the same dimension
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                C[i][j] = A[i][j] + B[i][j];
                
        return C;        
    }




    //subtract one matric by another
    public static int[][] sub(int[][] A, int[][] B){
        //A and B has the same dimension
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i=0; i<n; i++)
            for(int j=0; j<n; j++)
                C[i][j] = A[i][j] - B[i][j];
        return C;   
    }


    //Multiply matrix
    public static int[][] multiply(int[][] A, int[][] B){
        int n = A.length;
        int[][] R = new int[n][n];


        /**exit*/
        if(n==1)
            R[0][0] = A[0][0]+B[0][0];


        else{
            //divide A into 4 submatrix
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];


            split(A, A11, 0, 0);
            split(A, A12, 0, n/2);
            split(A, A21, n/2, 0);
            split(A, A22, n/2, n/2);


            //divide B into 4 submatric
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];


            split(B, B11, 0, 0);
            split(B, B12, 0, n/2);
            split(B, B21, n/2, 0);
            split(B, B22, n/2, n/2);


            //seven magic products
            int[][] P1 = multiply(A11, sub(B12, B22));
            int[][] P2 = multiply(add(A11,A12), B22);
            int[][] P3 = multiply(add(A21, A22), B11);
            int[][] P4 = multiply(A22, sub(B21, B11));
            int[][] P5 = multiply(add(A11, A22), add(B11, B22));
            int[][] P6 = multiply(sub(A12, A22), add(B21, B22));
            int[][] P7 = multiply(sub(A11, A21), add(B11, B12));




            //new 4 submatrix
            int[][] R11 = add(add(P5, sub(P4, P2)), P6);
            int[][] R12 = add(P1, P2);
            int[][] R21 = add(P3, P4);
            int[][] R22 = sub(sub(add(P1, P5), P3), P7);


            //joint together
            join(R11, R, 0, 0);
            join(R12, R, 0, n/2);
            join(R21, R, n/2, 0);
            join(R22, R, n/2, n/2);

        }
        return R;
    }


    //main 
    public static void main(String[] args){
        
        Scanner scan = new Scanner(System.in);
        System.out.println("Strassen Multiplication Algorithm Test\n");
        Strassen s = new Strassen();
 


        System.out.println("Fetch the matric A and B...");
        int N = scan.nextInt();
        int[][] A = new int[N][N];
        int[][] B = new int[N][N];


        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                A[i][j] = scan.nextInt();


        for (int i = 0; i < N; i++)
            for (int j = 0; j < N; j++)
                B[i][j] = scan.nextInt();


        System.out.println("Fetch Completed!");
 
        int[][] C = s.multiply(A, B);
        
        System.out.println("\nmatrices A = ");
        for (int i = 0; i < N; i++){
            for (int j = 0; j < N; j++)
                System.out.print(A[i][j] +" ");
            System.out.println();
        }


        System.out.println("\nmatrices B =");
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++)
                System.out.print(B[i][j] +" ");
            System.out.println();
        }
 
        System.out.println("\nProduct of matrices A and  B  = ");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(C[i][j] +" ");
            System.out.println();
        }
    }
}



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值