算法导论学习笔记—Strassen算法的Java实现

Strassen算法

    Strassen算法的核心思想是令递归树稍微不那么茂盛,相比于简单的“分而治之”的矩阵递归计算,其递归的分支由8条减少到7条。其时间复杂度为O(n的lg7次方)。虽然,它的算法中需要新增10个(n/2 * n/2)的中间矩阵S1-S10。每次子矩阵的加减运算会增加O(n平方/4)的时间消耗,所以代码在执行S1-S10时,这部分的时间复杂度为10*O(n平方/4),但是相比之下,Strassen算法减少了一次递归,所以时间复杂度上会减少。

   下面来看简单“分而治之”矩阵乘法和Strassen算法的对比:

(1) 简单“分而治之”矩阵乘法

package com.oracle.ThirdCharpter;
/**
 * 写一个简单的“分而治之的矩阵乘法”,即A = [A11 A12    B=[B11 B12   C=[A11*B11+A12*B21  A11*B12+A12*B22
 *                                        A21 A22]      B21 B22]     A21*B11+A22*B21  A21*B12+A22*B22]
 * 经过算法分析发现,其时间复杂度依然还是O(n的3次方)
 * @author zhegao
 *
 */
public class Practice1_4 {
    public int[][] matrix_multiply(int[][] a,int[][] b) {        
        if(a.length==1) {
            return new int[][] {{a[0][0]*b[0][0]}};
        }else {
            int[][] A11 = partition(a,1);
            int[][] B11 = partition(b,1);
            int[][] A12 = partition(a,2);
            int[][] B12 = partition(b,2);
            int[][] A21 = partition(a,3);
            int[][] B21 = partition(b,3);
            int[][] A22 = partition(a,4);
            int[][] B22 = partition(b,4);

            //进行加法运算
            int[][] C11 = matrixAdd(matrix_multiply(A11,B11),matrix_multiply(A12,B21));
            int[][] C12 = matrixAdd(matrix_multiply(A11,B12),matrix_multiply(A12,B22));
            int[][] C21 = matrixAdd(matrix_multiply(A21,B11),matrix_multiply(A22,B21));
            int[][] C22 = matrixAdd(matrix_multiply(A21,B12),matrix_multiply(A22,B22));
            int[][] C = merge(C11,C12,C21,C22);
            return C;
        }
        
    }
    //拆分矩阵,得到四个子矩阵,把不同位置的子矩阵标记成1,2,3,4。1——左上;2——右上;3——左下;4——右下
    public int[][] partition(int[][] arr,int index) {
        int len = arr.length;
        int[][] result = new int[len/2][len/2];
        switch(index) {
        case 1:
            for(int i=0;i<len/2;i++) {
                for(int j=0;j<len/2;j++) {                    
                    result[i][j]=arr[i][j];        
                    //System.out.println(result[i][j]);
                }
            };
            break;

        case 2:
            for(int i=0;i<len/2;i++) {
                for(int j=len/2;j<len;j++) {                
                    result[i][j-len/2]=arr[i][j];                        
                }
            };
            break;
        case 3:
            for(int i=len/2;i<len;i++) {                
                for(int j=0;j<len/2;j++) {                                    
                    result[i-len/2][j] = arr[i][j];
                }


            };
            break;
        case 4:
            for(int i=len/2;i<len;i++) {
                for(int j=len/2;j<len;j++) {                    
                    result[i-len/2][j-len/2]=arr[i][j];
                }
            };
        }
        return result;
    }
    //矩阵的加运算
    public int[][] matrixAdd(int[][] a,int[][] b){
        int[][] result = new int[a.length][a.length];
        for(int i=0;i<a.length;i++) {
            for(int j=0;j<a.length;j++) {
                result[i][j]=a[i][j]+b[i][j];
            }
        }
        return result;
    }
    public  void display(int[][] arr) {
        System.out.print("[");
        for(int i=0;i<arr.length;i++) {
            for(int j=0;j<arr.length;j++) {
                if(j==arr.length-1) {
                    System.out.print(arr[i][j]);
                }else {
                    System.out.print(arr[i][j]+ " ");
                }
            }
            System.out.print("]");
            System.out.print("\n");
        }    

    }

    //将四个子矩阵合并成一个整体的大矩阵
    public int[][] merge(int[][] a1,int[][] a2,int[][] a3,int[][] a4){
        int len = a1.length;
        int[][] result = new int[len*2][len*2];
        for(int i=0;i<result.length;i++) {
            if(i<len) {
                for(int j=0;j<result.length;j++) {
                    if(j<len) {
                        result[i][j]=a1[i][j];
                    }else {
                        result[i][j] = a2[i][j-len];
                    }
                }
            }else {
                for(int j=0;j<result.length;j++) {
                    if(j<len) {
                        result[i][j]=a3[i-len][j];
                    }else {
                        result[i][j] = a4[i-len][j-len];
                    }
                }
            }
        }
        return result;
    }
    public static void main(String[] args) {
        int[][] arr = new int[][] {{1,2,3,4},{5,6,7,8},{9,10,11,12},{13,14,15,16}};
        Practice1_4 prac = new Practice1_4();
        int[][] result1 = prac.partition(arr, 1);
        int[][] result2 = prac.partition(arr, 2);
        int[][] result3 = prac.partition(arr, 3);
        int[][] result4 = prac.partition(arr, 4);

      //测试矩阵分离方法

        prac.display(result1);
        prac.display(result2);
        prac.display(result3);
        prac.display(result4);
   
        //测试矩阵的合并方法
        int[][] merge = prac.merge(result1, result2, result3, result4);
        prac.display(merge);
        
        //测试分而治之的矩阵乘法
        int[][] a1 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};
        int[][] a2 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};
        prac.display(prac.matrix_multiply(a1, a2));
    }

}

(2)Strassen算法

package com.oracle.ThirdCharpter;

/**
 * 使用Strassen算法进行矩阵乘法
 * 相比于“分而治之”的矩阵乘法,Stassen算法的递归分支只有7条,所以其时间复杂度为O(log2 7)
 *
 * 分析:Strassen算法相比传统”分而治之“的算法,它的递归分支只有7条,P1-P7。
 * @author zhegao
 *
 */
public class Practice1_5 {

    public int[][] matrix_multiply(int[][] a,int[][] b) {        
        if(a.length==1) {
            return new int[][] {{a[0][0]*b[0][0]}};
        }else {
            int[][] A11 = partition(a,1);
            int[][] B11 = partition(b,1);
            int[][] A12 = partition(a,2);
            int[][] B12 = partition(b,2);
            int[][] A21 = partition(a,3);
            int[][] B21 = partition(b,3);
            int[][] A22 = partition(a,4);
            int[][] B22 = partition(b,4);
            
            //计算S1-S10的中间矩阵
            int[][] S1 = matrixSubstract(B12,B22);
            int[][] S2 = matrixAdd(A11,A12);
            int[][] S3 = matrixAdd(A21,A22);
            int[][] S4 = matrixSubstract(B21,B11);
            int[][] S5 = matrixAdd(A11,A22);
            int[][] S6 = matrixAdd(B11,B22);
            int[][] S7 = matrixSubstract(A12,A22);
            int[][] S8 = matrixAdd(B21,B22);
            int[][] S9 = matrixSubstract(A11,A21);
            int[][] S10 = matrixAdd(B11,B12);
            
            //计算P1-P7的几个递归矩阵
            int[][] P1 = matrix_multiply(A11,S1);
            int[][] P2 = matrix_multiply(S2,B22);
            int[][] P3 = matrix_multiply(S3,B11);
            int[][] P4 = matrix_multiply(A22,S4);
            int[][] P5 = matrix_multiply(S5,S6);
            int[][] P6 = matrix_multiply(S7,S8);
            int[][] P7 = matrix_multiply(S9,S10);
            
            //进行加减运算
            int[][] C11 = matrixAdd(matrixSubstract(matrixAdd(P5,P4),P2),P6);
            int[][] C12 = matrixAdd(P1,P2);
            int[][] C21 = matrixAdd(P3,P4);
            int[][] C22 = matrixSubstract(matrixSubstract(matrixAdd(P5,P1),P3),P7);
            
            //合并各个子矩阵
            int[][] C = merge(C11,C12,C21,C22);
            return C;
        }

    }

    //拆分矩阵,得到四个子矩阵,把不同位置的子矩阵标记成1,2,3,4。1——左上;2——右上;3——左下;4——右下
    public int[][] partition(int[][] arr,int index) {
        int len = arr.length;
        int[][] result = new int[len/2][len/2];
        switch(index) {
        case 1:
            for(int i=0;i<len/2;i++) {
                for(int j=0;j<len/2;j++) {                    
                    result[i][j]=arr[i][j];        
                    //System.out.println(result[i][j]);
                }
            };
            break;

        case 2:
            for(int i=0;i<len/2;i++) {
                for(int j=len/2;j<len;j++) {                
                    result[i][j-len/2]=arr[i][j];                        
                }
            };
            break;
        case 3:
            for(int i=len/2;i<len;i++) {                
                for(int j=0;j<len/2;j++) {                                    
                    result[i-len/2][j] = arr[i][j];
                }


            };
            break;
        case 4:
            for(int i=len/2;i<len;i++) {
                for(int j=len/2;j<len;j++) {                    
                    result[i-len/2][j-len/2]=arr[i][j];
                }
            };
        }
        return result;
    }
    
    //矩阵的加运算
    public int[][] matrixAdd(int[][] a,int[][] b){
        int[][] result = new int[a.length][a.length];
        for(int i=0;i<a.length;i++) {
            for(int j=0;j<a.length;j++) {
                result[i][j]=a[i][j]+b[i][j];
            }
        }
        return result;
    }
    
    //矩阵的减运算
    public int[][] matrixSubstract(int[][] a,int[][] b){
        int[][] result = new int[a.length][a.length];
        for(int i=0;i<a.length;i++) {
            for(int j=0;j<a.length;j++) {
                result[i][j]=a[i][j]-b[i][j];
            }
        }
        return result;
    }
    public  void display(int[][] arr) {
        System.out.print("[");
        for(int i=0;i<arr.length;i++) {
            for(int j=0;j<arr.length;j++) {
                if(j==arr.length-1) {
                    System.out.print(arr[i][j]);
                }else {
                    System.out.print(arr[i][j]+ " ");
                }
            }
            System.out.print("]");
            System.out.print("\n");
        }    
    }
    
    //将四个子矩阵合并成一个整体的大矩阵
        public int[][] merge(int[][] a1,int[][] a2,int[][] a3,int[][] a4){
            int len = a1.length;
            int[][] result = new int[len*2][len*2];
            for(int i=0;i<result.length;i++) {
                if(i<len) {
                    for(int j=0;j<result.length;j++) {
                        if(j<len) {
                            result[i][j]=a1[i][j];
                        }else {
                            result[i][j] = a2[i][j-len];
                        }
                    }
                }else {
                    for(int j=0;j<result.length;j++) {
                        if(j<len) {
                            result[i][j]=a3[i-len][j];
                        }else {
                            result[i][j] = a4[i-len][j-len];
                        }
                    }
                }
            }
            return result;
        }
    public static void main(String[] args) {
        //测试分而治之的矩阵乘法
        Practice1_5 prac = new Practice1_5();
        int[][] a1 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};
        int[][] a2 = new int[][] {{1,2,3,4},{4,3,2,1},{0,1,2,3},{3,2,1,0}};
        prac.display(prac.matrix_multiply(a1, a2));
    }

}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值