用分治策略实现Strassen矩阵乘法

 

(1)这里我选择使用vector容器,创建一个二维的 std::vector 容器,存储整数类型的数据,用于存储矩阵元素。

vector来表示矩阵可以提供更多的灵活性、安全性和便利性,特别是在需要动态大小、避免内存管理问题的情况下。

 

定义如下:

vector<vector<int>> data;

 

(2)Strassen 矩阵乘法是一种递归分治算法,用于将大矩阵的乘法分解为较小矩阵的乘法。

 

传统矩阵乘法的计算思路是基于矩阵的乘法定义,按照行乘以列的方式逐个元素相乘,然后将结果累加得到新矩阵的元素值。

假设现在有一个rows x cols 的矩阵(记为矩阵 A)和一个 other.rows x other.cols 的矩阵(记为矩阵 B),要求矩阵 A 的列数(cols)必须等于矩阵 B 的行数(other.rows)才能进行乘法操作。

普通乘法的过程如下:

创建一个新的结果矩阵(记为矩阵 result),其维度为 rows x other.cols,即行数等于矩阵 A 的行数,列数等于矩阵 B 的列数。

对于矩阵 C 中的每个元素 C[i][j],其中 i 表示行索引,j 表示列索引,计算如下:

 

C[i][j] = A[i][0] * B[0][j] + A[i][1] * B[1][j] + ... + A[i][cols-1] * B[cols-1][j]

这个计算过程实际上是在遍历矩阵 A 的第 i 行和矩阵 B 的第 j 列,将对应元素相乘并将它们累加起来,得到矩阵 C 的元素值。

重复上述步骤,直到计算完矩阵 C 的所有元素。

伪代码如下:

Matrix result(rows, other.cols);

        for (int i = 0; i < rows; i++) {

            for (int j = 0; j < other.cols; j++) {

                for (int k = 0; k < cols; k++) {

                    result.data[i][j] += data[i][k] * other.data[k][j];

                }

            }

        }

        return result;

可以很容易看出需要三个嵌套循环,得到时间复杂度为 O(rows * other.cols * cols)

假设均为NxN矩阵,则时间复杂度为O(n^3)

 

 

Strassen大矩阵乘法基本思路是将矩阵分割成较小的子矩阵,然后通过一系列中间矩阵的计算来组合子矩阵的乘积,从而得到最终的结果。

首先,设置递归终止条件。如果矩阵大小变得足够小(通常是一个阈值,例如 1x1 矩阵),则直接执行传统的矩阵乘法。

如果矩阵不是足够小,则将每个输入矩阵 A 和 B 分割成四个相等大小的子矩阵,每个子矩阵都是原矩阵的一部分。

         for (int i = 0; i < halfRows; i++) {

                for (int j = 0; j < halfCols; j++) {

                    A11.data[i][j] = A.data[i][j];

                    A12.data[i][j] = A.data[i][j + halfCols];

                    A21.data[i][j] = A.data[i + halfRows][j];

                    A22.data[i][j] = A.data[i + halfRows][j + halfCols];

 

                    B11.data[i][j] = B.data[i][j];

                    B12.data[i][j] = B.data[i][j + halfCols];

                    B21.data[i][j] = B.data[i + halfRows][j];

                    B22.data[i][j] = B.data[i + halfRows][j + halfCols];

                }

            }

将输入矩阵分割成四个子矩阵需要常数时间,这是一个线性开销,与矩阵大小无关。

计算中间矩阵M1到M7,这里使用递归调用,不断调用Strassen乘法

       Matrix M1 = (A11 + A22).strassenMultiply(B11 + B22);

       Matrix M2 = (A21 + A22).strassenMultiply(B11);

       Matrix M3 = A11.strassenMultiply(B12 - B22);

       Matrix M4 = A22.strassenMultiply(B21 - B11);

       Matrix M5 = (A11 + A12).strassenMultiply(B22);

       Matrix M6 = (A21 - A11).strassenMultiply(B11 + B12);

       Matrix M7 = (A12 - A22).strassenMultiply(B21 + B22);

在递归步骤中,有七个子问题,每个子问题涉及的矩阵维度都是原始维度的一半。因此,递归深度为 log₂(n)。

在每一层递归中,我们需要计算这七个子问题,每个子问题需要进行一次矩阵乘法,这一步的时间复杂度为 O(n²)

计算结果矩阵

       Matrix C11 = M1 + M4 - M5 + M7;

        Matrix C12 = M3 + M5;

        Matrix C21 = M2 + M4;

        Matrix C22 = M1 - M2 + M3 + M6;

合并矩阵

Matrix result(rows, other.cols);

            for (int i = 0; i < halfRows; i++) {

                for (int j = 0; j < halfCols; j++) {

                    result.data[i][j] = C11.data[i][j];

                    result.data[i][j + halfCols] = C12.data[i][j];

                    result.data[i + halfRows][j] = C21.data[i][j];

                    result.data[i + halfRows][j + halfCols] = C22.data[i][j];

                }

            }

         return result;

     合并过程与分割过程同理

 

  1. 实现Strassen算法

     定义Matrix矩阵类

6c78361b34754ebdb1831cda70e1d554.png

 

 

 私有成员变量有row、cols表示矩阵行和列,使用二维动态数组data存储矩阵中的元素

 

87a7c0f6acf744bcb11618c0f8575bcd.png

 

构造函数,调用resize()生成指定行列的默认初始值为0的矩阵

bdd7e0f671b940639b07df80de8c2539.png

简单的set、get函数

6e23f479886d45568aaee1117a637fca.png

为了实现Strassen矩阵乘法,需要首先实现关于矩阵加法、减法和乘法的函数重载

以方便进行分割矩阵后的中间操作

19d26c5f151f417cb16c9f4fb289a43f.png

     构造函数,用于给numRow行numCol列矩阵指定初始值,便于Strassen矩阵中设置递归终止条件44f61e3ee77f43b99acf00975b563385.png

Strassen乘法函数

首先设置递归终止条件

b23cfc11dd764dc2bab4f1358b17b983.png

定义子矩阵

5ef4fedad7754620a6d39d0bddd211cd.png

分割子矩阵

2a00f87cfc954e6c8534996f52375c9b.png

计算中间矩阵M1到M7和结果矩阵C11到C22

c08e35f24923437aa185b0d21e271130.png

合并C11至C22矩阵到结果矩阵

f4fd4998f84c444facb3cff862fd6aca.png

打印函数

e4aff158a8234fb4a9437c3cbb43a4ed.png

现在可以实现矩阵加法、减法、普通乘法和Strassen大矩阵乘法了

  1. 分析算法的时空复杂性

使用随机函数生成NxN列矩阵,以便后续使用Clock函数计算cpu执行时间

bdedfeb183db4cc09722ce8b6bff057f.png

    现在可以比较普通矩阵乘法与Strassen矩阵乘法的时间开销了

对于空间复杂性,首先需要考虑代码中使用的数据结构

Matrix 对象的空间复杂度

1.int rows 和 int cols:这两个成员变量占用常量空间,不随矩阵大小变化而变化。

2.vector<vector<int>> data:这是一个二维动态数组,其空间复杂度取决于矩阵的大小,即 rows * cols。因此,它的空间复杂度为 O(rows * cols)。

    strassenMultiply 函数的空间复杂度:

1.递归深度取决于输入矩阵的大小。在每一层递归中,都会创建一些临时矩阵,包括 A11, A12, A21, A22, B11, B12, B21, B22 以及 M1, M2, M3, M4, M5, M6, M7。

2.每个子矩阵的空间复杂度取决于其大小。如果输入矩阵的大小是 N x N,则每个子矩阵的大小是 N/2 x N/2。

3.从 N 减半到 1,需要经过 log2(N) 步,因此,递归深度是 log2(N)。

对于每个矩阵的子矩阵 A11, A12, A21, A22, B11, B12, B21, B22,每个子矩阵的大小都是原始矩阵的一半。这意味着总共需要额外的空间来存储这些子矩阵。假设原始矩阵的大小为 N x N,那么递归到一阶矩阵的情况下,子矩阵的最大大小为 (N/2) x (N/2)。因此,子矩阵所需的总空间为 O(N^2/4) = O(N^2)。

在每个递归层级中,会创建临时矩阵 M1, M2, M3, M4, M5, M6, M7,它们的大小与子矩阵相同,即 (N/2) x (N/2)。因此,在每个递归层级中,这些临时矩阵所需的总空间也是 O(N^2)。

结果矩阵 result 的大小与原始矩阵相同,即 N x N。因此,结果矩阵所需的空间也是 O(N^2)。

      这样每一层递归都会导致O(N^2)的额外空间复杂度,那么递归log2(N)步后,总的空间复杂度应该是O(log2(N) * N^2)

首先测试算法是否可以正确执行

首先让用户生成A、B矩阵

2929c1bdcf7e49809d5ec2265b6badb5.png

然后测试加减以及普通乘法Strassen乘法运算

0349206f977647a0ab2d578db6d39b8d.png

测试效果

3d6eb4e632084ba48e45ef2c9804dd66.png

接下来测试时间开销

  1fb2113f0e3e481eaec5f7f3d62d0adf.png

发现在N不是很大的时候普通矩阵算法cpu执行时间要小

c62c3c6f3241427390b29f2e81ffb6eb.png

在N很大时,发现在普通算法执行结束后Strassen算法很长一段时间后仍然无法输出

8ad31b40dd074e35921fde7cdcf63f36.png

上文我们发现使用Strassen算法在处理大矩阵时,耗时非但没有减少,反而增多,在n=1024时就无法输出,效果没有普通矩阵乘法算法好。网上查阅资料,将我认为的原因归纳如下:

  1. 额外的矩阵分割和合并开销:Strassen 算法需要将输入矩阵分割成四个子矩阵,并且需要在计算结果时将这些子矩阵合并成最终的结果矩阵。这些分割和合并操作涉及到额外的内存分配和数据复制操作,需要创建大量的动态二维数组,其中分配堆内存空间将占用大量计算时间,这会增加算法的开销。
  2. 递归开销:Strassen 算法是一个递归算法,它在递归的每一层都要执行多次矩阵分割、计算中间矩阵和合并操作。递归调用本身会引入函数调用和堆栈开销,特别是在小矩阵的情况下,递归开销可能会占据算法执行时间的相当一部分。
  3. 小矩阵的限制:Strassen 算法通常在矩阵达到一定大小时才能比传统矩阵乘法更高效。对于小型矩阵,Strassen 算法的额外开销可能会超过其性能提升,因此在小矩阵上执行 Strassen 算法可能不如传统矩阵乘法效率高。

    改进策略

    给Strassen算法设定一个下界。当n<界限时,使用普通乘法计算矩阵,而不继续分治递归

代码修改如下,这里先假设下界为512

5495d6a4832a4911a06e2aea604c7c91.png

运行结果如下

dc5830e00b214b11bade1d7df80e40f0.png

设定下界为256

运行结果如下

8b8b5f9f8be44d44844e35a874051aa1.png

基于不同的矩阵阶数,下界应该有所调整,不过可以看得出来在阶数高的矩阵乘法上Strassen矩阵乘法具有明显的优势。

 

  • 16
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值