Strassen:一种高效的矩阵相乘算法

Q1:如何提高矩阵相乘效率?

矩阵乘法是种极其耗时的运算。
以C = A • B为例,其中A和B都是 n x n 的方阵。根据矩阵乘法的定义,计算过程如下:
在这里插入图片描述

A1: 简单分治法

前提:假定A,B都是n等于2的次幂的方阵

基本思路:计算C=A*B时,将C,A,B矩阵进行分块操作,对每个分块的矩阵进行乘法运
算,运算完毕后重新对得到的C11,C12,C21,C22进行组合操作。

确定递归终止条件:当分块矩阵得到的阶数为1 时,得到的C即是A和B中两个元素的乘积。

在这里插入图片描述
每个公式需要计算两次矩阵乘法和一次矩阵加法,使用T(n)表示 n x n 矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到一个递推公式。

T(n) = 8T(n/2) + Θ(n²)

其中,8T(n/2)表示8次矩阵乘法,而且相乘的矩阵规模降到了n/2。Θ(n²)表示4次矩阵加法的时间复杂度以及合并C矩阵的时间复杂度。

要想计算出T(n)并不复杂,可以采用画递归树的方式计算,或采用“主方法”直接计算。
结果是:

T(n) = Θ(n³)

可见,简单的分治策略并没有起到加速运算的效果。

简单分治策略为什么无法提高速度?

因为分解后的问题包含了8次矩阵相乘和4次矩阵相加,就是这8次矩阵相乘导致了速度不能提升。

于是我们想到能不能减少矩阵相乘的次数,取而代之的是矩阵相加的次数增加?
Strassen正是利用了这一点。

A2: Strassen

在这里插入图片描述
Strassen计算量是7个矩阵乘法和18个矩阵加法。虽然矩阵加法增加了好几倍,而矩阵乘法只减小了1个,但在数量级面前,18个加法仍然渐进快于1个乘法。

使用递归树或主方法可以计算出结果:T(n) = Θ(nlg7) ≈ Θ(n2.81)
在这里插入图片描述

代码:
在这里插入图片描述

output:
在这里插入图片描述

Q2:如何计算任意偶数阶矩阵相乘?

对于任意偶数n,总有n=m*(2^k)
将矩阵分成m*m个2^k阶矩阵。
举例:6X6矩阵相乘,可分为9个2x2矩阵,大矩阵用传统算法,小矩阵用Strassen。

在这里插入图片描述

Q3:对于任意正整数n,如何求n阶矩阵相乘?

将不是2的次幂的矩阵扩展成2的次幂的矩阵,在多出的行和列上添上0元素,在计算结果重新组合成c后,对c矩阵多出的行和列上的0元素舍去。

验证:
在这里插入图片描述
因此在原先的基础上增加了矩阵的拓展和缩略函数。
在主函数中,首先对输入矩阵A,B的阶数进行判断,如果是2的次幂则不用进行任何操作,直接用普通的Strassen算法,如果不是2的次幂,先对A,B进行矩阵拓展,在计算得到的结果后进行矩阵缩略。

over~

  • 3
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
传统方法和Strassen算法是两种不同的矩阵相乘算法。传统方法的时间复杂度为O(n^3),而Strassen算法的时间复杂度为O(n^log2(7)),其中log2(7)约等于2.807。 因此,为了结合两种算法的优点,我们可以采用一个分治的思想,对矩阵的大小进行适当的划分,在两种算法之间进行选择。 具体实现如下: ```java public class MatrixMultiply { // 传统矩阵相乘算法 public static int[][] multiply(int[][] A, int[][] B) { int m = A.length; int n = A[0].length; int p = B[0].length; int[][] C = new int[m][p]; for (int i = 0; i < m; i++) { for (int j = 0; j < p; j++) { for (int k = 0; k < n; k++) { C[i][j] += A[i][k] * B[k][j]; } } } return C; } // Strassen算法 public static int[][] strassen(int[][] A, int[][] B) { int n = A.length; if (n <= 64) { return multiply(A, B); // 当矩阵大小小于等于64时,使用传统算法 } int[][] A11 = new int[n / 2][n / 2]; int[][] A12 = new int[n / 2][n / 2]; int[][] A21 = new int[n / 2][n / 2]; int[][] A22 = new int[n / 2][n / 2]; int[][] B11 = new int[n / 2][n / 2]; int[][] B12 = new int[n / 2][n / 2]; int[][] B21 = new int[n / 2][n / 2]; int[][] B22 = new int[n / 2][n / 2]; // 将矩阵A、B分成四个子矩阵 for (int i = 0; i < n / 2; i++) { for (int j = 0; j < n / 2; j++) { A11[i][j] = A[i][j]; A12[i][j] = A[i][j + n / 2]; A21[i][j] = A[i + n / 2][j]; A22[i][j] = A[i + n / 2][j + n / 2]; B11[i][j] = B[i][j]; B12[i][j] = B[i][j + n / 2]; B21[i][j] = B[i + n / 2][j]; B22[i][j] = B[i + n / 2][j + n / 2]; } } // 计算7个子矩阵 int[][] M1 = strassen(add(A11, A22), add(B11, B22)); int[][] M2 = strassen(add(A21, A22), B11); int[][] M3 = strassen(A11, sub(B12, B22)); int[][] M4 = strassen(A22, sub(B21, B11)); int[][] M5 = strassen(add(A11, A12), B22); int[][] M6 = strassen(sub(A21, A11), add(B11, B12)); int[][] M7 = strassen(sub(A12, A22), add(B21, B22)); // 计算结果矩阵C的四个子矩阵 int[][] C11 = add(sub(add(M1, M4), M5), M7); int[][] C12 = add(M3, M5); int[][] C21 = add(M2, M4); int[][] C22 = add(sub(add(M1, M3), M2), M6); // 将四个子矩阵合并成一个大矩阵 int[][] C = new int[n][n]; for (int i = 0; i < n / 2; i++) { for (int j = 0; j < n / 2; j++) { C[i][j] = C11[i][j]; C[i][j + n / 2] = C12[i][j]; C[i + n / 2][j] = C21[i][j]; C[i + n / 2][j + n / 2] = C22[i][j]; } } return C; } // 矩阵加法 public static int[][] add(int[][] A, int[][] B) { int n = A.length; int[][] C = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { C[i][j] = A[i][j] + B[i][j]; } } return C; } // 矩阵减法 public static int[][] sub(int[][] A, int[][] B) { int n = A.length; int[][] C = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { C[i][j] = A[i][j] - B[i][j]; } } return C; } // 随机生成一个n*n的矩阵 public static int[][] generateMatrix(int n) { int[][] A = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { A[i][j] = (int) (Math.random() * 10); } } return A; } // 打印矩阵 public static void printMatrix(int[][] A) { int n = A.length; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { System.out.print(A[i][j] + " "); } System.out.println(); } } public static void main(String[] args) { int n = 8; int[][] A = generateMatrix(n); int[][] B = generateMatrix(n); System.out.println("矩阵A:"); printMatrix(A); System.out.println("矩阵B:"); printMatrix(B); int[][] C = strassen(A, B); System.out.println("矩阵C:"); printMatrix(C); } } ``` 在上述代码中,程序首先生成两个大小为n\*n的随机矩阵A和B,然后调用strassen方法计算它们的乘积。当矩阵大小小于等于64时,程序使用传统矩阵相乘算法。否则,程序将矩阵A和B分成四个子矩阵,递归地调用strassen方法计算它们的乘积,最后将结果矩阵的四个子矩阵合并成一个大矩阵

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值