分治算法-Strassen矩阵乘法

学过线性代数的都知道两个 N×N 矩阵的乘法怎么求,按定义直接计算的复杂度为 O(N^3)。下面直接给出标准算法。

代码:

public class MartixMultiply {
    /**
     * Multiplies two integer matrices with the schoolbook triple loop.
     *
     * <p>{@code a} is read as an {@code a.length x b.length} matrix and {@code b}
     * as a {@code b.length x b[0].length} matrix, so square operands are simply
     * the special case where all three dimensions agree. Sums use plain
     * {@code int} arithmetic and therefore wrap on overflow.
     *
     * @param a left operand; every row must have exactly {@code b.length} entries
     * @param b right operand; every row must have the same length
     * @return a newly allocated {@code a.length x b[0].length} matrix holding
     *         {@code a * b} (with zero columns when {@code b} has no rows)
     * @throws IllegalArgumentException if the shapes are incompatible or
     *         {@code b} is ragged
     * @throws NullPointerException if a matrix or one of its rows is {@code null}
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = inner == 0 ? 0 : b[0].length;

        for (int k = 0; k < inner; k++) {
            if (b[k].length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: row " + k + " has " + b[k].length
                                + " columns, expected " + cols);
            }
        }
        for (int i = 0; i < rows; i++) {
            if (a[i].length != inner) {
                throw new IllegalArgumentException(
                        "row " + i + " of a has " + a[i].length
                                + " columns but b has " + inner + " rows");
            }
        }

        // Java zero-fills freshly allocated int arrays, so no separate
        // initialisation pass is required before accumulating.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int k = 0; k < inner; k++) {
                // Hoisted: a[i][k] is constant across the innermost loop.
                int aik = a[i][k];
                for (int j = 0; j < cols; j++) {
                    c[i][j] += aik * b[k][j];
                }
            }
        }
        return c;
    }

    /**
     * Demo entry point: multiplies two fixed 2x2 matrices and prints the four
     * entries of the product in row-major order, separated by spaces.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 3, 4 }, { 7, 2 } };
        int[][] c = multiply(a, b);

        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}

Strassen提出的算法打破了O(N^3)的屏障。它采用分治策略,把每个矩阵分为4块:

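$$A=\begin{pmatrix}A_{00}&A_{01}\\A_{10}&A_{11}\end{pmatrix},\quad B=\begin{pmatrix}B_{00}&B_{01}\\B_{10}&B_{11}\end{pmatrix},\quad C=AB=\begin{pmatrix}C_{00}&C_{01}\\C_{10}&C_{11}\end{pmatrix}$$

$$C_{00}=M_1+M_4-M_5+M_7,\quad C_{01}=M_3+M_5,\quad C_{10}=M_2+M_4,\quad C_{11}=M_1+M_3-M_2+M_6$$

其中

$$\begin{aligned}
M_1&=(A_{00}+A_{11})(B_{00}+B_{11})\\
M_2&=(A_{10}+A_{11})B_{00}\\
M_3&=A_{00}(B_{01}-B_{11})\\
M_4&=A_{11}(B_{10}-B_{00})\\
M_5&=(A_{00}+A_{01})B_{11}\\
M_6&=(A_{10}-A_{00})(B_{00}+B_{01})\\
M_7&=(A_{01}-A_{11})(B_{10}+B_{11})
\end{aligned}$$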
每一层只需要7次(而不是8次)子矩阵乘法,外加常数次 O(N²) 的矩阵加减法,因此可以得到递推关系 T(N)=7T(N/2)+O(N²),依据主定理得到解 T(N)=O(N^{log₂7})≈O(N^2.81)。
这里不给出证明,显然这用到了分治法的思想。

代码:

public class MartixMultiply {
    /** Order at or below which the recursion hands over to the schoolbook product. */
    private static final int BASE_CASE_ORDER = 2;

    /**
     * Multiplies two square integer matrices of the same order using
     * Strassen's seven-product divide-and-conquer scheme.
     *
     * <p>Any order is accepted, including 0 and 1. When the order at some
     * recursion level is odd, the operands are padded with one extra zero row
     * and column before being split, and the padding is removed from the
     * product afterwards; zero padding does not change the top-left block of
     * the product.
     *
     * @param a left operand, {@code n x n}
     * @param b right operand, {@code n x n}
     * @return a newly allocated {@code n x n} matrix holding {@code a * b}
     * @throws IllegalArgumentException if either operand is not square or the
     *         two orders differ
     * @throws NullPointerException if a matrix or one of its rows is {@code null}
     */
    public static int[][] StrassenMultiply(int[][] a, int[][] b) {
        int n = a.length;
        requireSquare(a, n, "a");
        requireSquare(b, n, "b");
        return strassen(a, b, n);
    }

    /** Recursive core; both operands are already known to be {@code n x n}. */
    private static int[][] strassen(int[][] a, int[][] b, int n) {
        // Small matrices (including orders 0 and 1) end the recursion and use
        // the conventional method.
        if (n <= BASE_CASE_ORDER) {
            return multiply(a, b);
        }
        // An odd order cannot be split into four equal quadrants: pad to the
        // next even order, multiply, then drop the padding again.
        if (n % 2 != 0) {
            int[][] padded = strassen(resize(a, n + 1), resize(b, n + 1), n + 1);
            return resize(padded, n);
        }

        int half = n / 2;

        // The four quadrants of a.
        int[][] a00 = quadrant(a, 0, 0, half);
        int[][] a01 = quadrant(a, 0, half, half);
        int[][] a10 = quadrant(a, half, 0, half);
        int[][] a11 = quadrant(a, half, half, half);
        // The four quadrants of b.
        int[][] b00 = quadrant(b, 0, 0, half);
        int[][] b01 = quadrant(b, 0, half, half);
        int[][] b10 = quadrant(b, half, 0, half);
        int[][] b11 = quadrant(b, half, half, half);

        // Strassen's seven half-size products.
        int[][] m1 = strassen(add(a00, a11), add(b00, b11), half);
        int[][] m2 = strassen(add(a10, a11), b00, half);
        int[][] m3 = strassen(a00, subtract(b01, b11), half);
        int[][] m4 = strassen(a11, subtract(b10, b00), half);
        int[][] m5 = strassen(add(a00, a01), b11, half);
        int[][] m6 = strassen(subtract(a10, a00), add(b00, b01), half);
        int[][] m7 = strassen(subtract(a01, a11), add(b10, b11), half);

        // Recombine the products into the four quadrants of the result.
        int[][] result = new int[n][n];
        paste(result, add(m7, subtract(add(m1, m4), m5)), 0, 0);       // m1+m4-m5+m7
        paste(result, add(m3, m5), 0, half);                           // m3+m5
        paste(result, add(m2, m4), half, 0);                           // m2+m4
        paste(result, add(m6, subtract(add(m1, m3), m2)), half, half); // m1+m3-m2+m6
        return result;
    }

    /** Rejects {@code m} unless it has exactly {@code order} rows of {@code order} entries. */
    private static void requireSquare(int[][] m, int order, String name) {
        if (m.length != order) {
            throw new IllegalArgumentException(
                    name + " has " + m.length + " rows, expected " + order);
        }
        for (int i = 0; i < order; i++) {
            if (m[i].length != order) {
                throw new IllegalArgumentException(
                        name + " is not square: row " + i + " has "
                                + m[i].length + " columns, expected " + order);
            }
        }
    }

    /**
     * Copies the top-left corner of {@code m} into a new {@code order x order}
     * matrix; cells outside the copied corner stay zero.
     */
    private static int[][] resize(int[][] m, int order) {
        int[][] out = new int[order][order];
        int keep = Math.min(m.length, order);
        for (int i = 0; i < keep; i++) {
            System.arraycopy(m[i], 0, out[i], 0, keep);
        }
        return out;
    }

    /** Extracts the {@code half x half} sub-matrix of {@code m} starting at the given offsets. */
    private static int[][] quadrant(int[][] m, int rowOffset, int colOffset, int half) {
        int[][] out = new int[half][half];
        for (int i = 0; i < half; i++) {
            System.arraycopy(m[rowOffset + i], colOffset, out[i], 0, half);
        }
        return out;
    }

    /** Element-wise sum of two square matrices of the same order. */
    private static int[][] add(int[][] x, int[][] y) {
        int n = x.length;
        int[][] out = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                out[i][j] = x[i][j] + y[i][j];
            }
        }
        return out;
    }

    /** Element-wise difference {@code x - y} of two square matrices of the same order. */
    private static int[][] subtract(int[][] x, int[][] y) {
        int n = x.length;
        int[][] out = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                out[i][j] = x[i][j] - y[i][j];
            }
        }
        return out;
    }

    /** Copies {@code block} into {@code target} with its top-left cell at the given offsets. */
    private static void paste(int[][] target, int[][] block, int rowOffset, int colOffset) {
        for (int i = 0; i < block.length; i++) {
            System.arraycopy(block[i], 0, target[rowOffset + i], colOffset, block[i].length);
        }
    }

    /**
     * Conventional (schoolbook) matrix product.
     *
     * <p>{@code a} is read as an {@code a.length x b.length} matrix and {@code b}
     * as a {@code b.length x b[0].length} matrix. Sums use plain {@code int}
     * arithmetic and therefore wrap on overflow.
     *
     * @param a left operand; every row must have exactly {@code b.length} entries
     * @param b right operand; every row must have the same length
     * @return a newly allocated {@code a.length x b[0].length} matrix holding
     *         {@code a * b} (with zero columns when {@code b} has no rows)
     * @throws IllegalArgumentException if the shapes are incompatible or
     *         {@code b} is ragged
     * @throws NullPointerException if a matrix or one of its rows is {@code null}
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = inner == 0 ? 0 : b[0].length;

        for (int k = 0; k < inner; k++) {
            if (b[k].length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: row " + k + " has " + b[k].length
                                + " columns, expected " + cols);
            }
        }
        for (int i = 0; i < rows; i++) {
            if (a[i].length != inner) {
                throw new IllegalArgumentException(
                        "row " + i + " of a has " + a[i].length
                                + " columns but b has " + inner + " rows");
            }
        }

        // Java zero-fills freshly allocated int arrays, so no separate
        // initialisation pass is required before accumulating.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int k = 0; k < inner; k++) {
                // Hoisted: a[i][k] is constant across the innermost loop.
                int aik = a[i][k];
                for (int j = 0; j < cols; j++) {
                    c[i][j] += aik * b[k][j];
                }
            }
        }
        return c;
    }

    /**
     * Demo entry point: multiplies two fixed 4x4 matrices with
     * {@link #StrassenMultiply(int[][], int[][])} and prints the top-left 2x2
     * corner of the product in row-major order, separated by spaces.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2, 6, 7 }, { 3, 4, 5, 4 }, { 5, 8, 3, 8 },
                { -6, 4, 3, 9 } };
        int[][] b = { { 3, 4, 9, 0 }, { 7, 2, -5, -6 }, { 0, 7, -4, 6 },
                { -6, 3, -5, 4 } };
        // Exercise the Strassen implementation this listing is about.
        int[][] c = StrassenMultiply(a, b);

        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值