《算法导论》——矩阵乘法Strassen算法

29 篇文章 0 订阅
8 篇文章 0 订阅

注:本文为《算法导论》中分治相关内容的笔记。对此感兴趣的读者还望支持原作者。

矩阵乘法

接触过线性代数的读者,对于矩阵乘法想必一定不陌生。若 A = ( a i j ) A=(a_{ij}) A=(aij) B = ( b i j ) B=(b_{ij}) B=(bij) n ∗ n n*n nn的方阵,则对 i , j , … , n i, j, \ldots, n i,j,,n,定义乘积 C = A ⋅ B C=A \cdot B C=AB中的元素 c i j c_{ij} cij为:

c i j = ∑ k = 1 n a i k b k j c_{ij}=\sum_{k=1}^{n}a_{ik}b_{kj} cij=k=1naikbkj

因此,我们可以根据矩阵乘法的定义给出矩阵乘法的伪代码。它接收 n ∗ n n * n nn的矩阵 A A A B B B,返回它们的乘积—— n ∗ n n * n nn的矩阵 C C C,并且假设每个矩阵都有一个属性 r o w s rows rows,表示矩阵的行数。

朴素算法

不难看出,由于三重for循环都恰好执行 n n n步,而第7行每次执行都花费常量时间。因此,SQUARE-MATRIX-MULTIPLY的时间复杂度为 θ ( n 3 ) \theta (n^3) θ(n3),即矩阵乘法的朴素实现需要花费 θ ( n 3 ) \theta (n^3) θ(n3)时间。你可能因此认为任何矩阵乘法都要花费 Ω ( n 3 ) \Omega (n^3) Ω(n3)时间,因为矩阵乘法的自然定义就需要进行这么多次的标量乘法。而在学术界,也的确在很长一段时间内,很少人敢设想一个算法能渐近快于平凡算法SQUARE-MATRIX-MULTIPLY,直至Strassen大神的出现。

算法流程

Strassen算法采用分治法解决矩阵乘积问题,并通过排列组合的技巧使得分治法产生的递归树不那么“茂盛”以减少矩阵乘法的次数。Strassen算法并不直观,它包含4个步骤:

  1. 将输入矩阵 A 、 B A、B AB和输出矩阵 C C C通过以下方式分解为 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n的子矩阵;
    A = [ A 11 A 12 A 21 A 22 ] , B = [ B 11 B 12 B 21 B 22 ] , C = [ C 11 C 12 C 21 C 22 ] A = \left [ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \\ \end{matrix} \right ], B = \left [ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \\ \end{matrix} \right ], C = \left [ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \\ \end{matrix} \right ] A=[A11A21A12A22],B=[B11B21B12B22],C=[C11C21C12C22]

  2. 创建10个 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n的矩阵 S 1 , S 2 , … , S 10 S_1, S_2, \ldots , S_{10} S1,S2,,S10,每个矩阵保存步骤1中创建的两个子矩阵的和或差,时间复杂度为 Θ ( n 2 ) \Theta (n^2) Θ(n2)

  3. 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归地计算7个矩阵积 P 1 , P 2 , … , P 7 P_1, P_2, \ldots , P_7 P1,P2,,P7。每个矩阵 P i P_i Pi都是 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n的;

  4. 通过 P i P_i Pi矩阵的不同组合进行加减计算,计算出矩阵 C C C的子矩阵 C 11 , C 12 , C 21 , C 22 C_{11}, C_{12}, C_{21}, C_{22} C11,C12,C21,C22,时间复杂度为 Θ ( n 2 ) \Theta(n^2) Θ(n2)

是不是感觉很抽象?一顿猛如虎的操作,就能完成矩阵乘积计算了?没错,就是这么。接下来,为了帮助大家掌握这种操作,就再看看Strassen算法的细节。在步骤2中,创建如下10个矩阵:

S 1 = B 12 − B 22 S_1 = B_{12} - B_{22} S1=B12B22

S 2 = A 11 + A 12 S_2 = A_{11} + A_{12} S2=A11+A12

S 3 = A 21 + A 22 S_3 = A_{21} + A_{22} S3=A21+A22

S 4 = B 21 − B 11 S_4 = B_{21} - B_{11} S4=B21B11

S 5 = A 11 + A 22 S_5 = A_{11} + A_{22} S5=A11+A22

S 6 = B 11 + B 22 S_6 = B_{11} + B_{22} S6=B11+B22

S 7 = A 12 − A 22 S_7 = A_{12} - A_{22} S7=A12A22

S 8 = B 21 + B 22 S_8 = B_{21} + B_{22} S8=B21+B22

S 9 = A 11 − A 21 S_9 = A_{11} - A_{21} S9=A11A21

S 10 = B 11 + B 22 S_{10} = B_{11} + B_{22} S10=B11+B22

由于必须进行10次 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n的加减法,因此,该步骤花费 Θ ( n 2 ) \Theta(n^2) Θ(n2)

在步骤三中,递归地计算7次 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n矩阵的乘法,如下所示:

P 1 = A 11 ⋅ S 1 = A 11 ⋅ B 12 − A 11 ⋅ B 22 P_1 = A_{11} \cdot S_1 = A_{11} \cdot B_{12} - A_{11} \cdot B_{22} P1=A11S1=A11B12A11B22

P 2 = S 2 ⋅ B 22 = A 11 ⋅ B 22 + A 12 ⋅ B 22 P_2 = S_2 \cdot B_{22} = A_{11} \cdot B_{22} + A_{12} \cdot B_{22} P2=S2B22=A11B22+A12B22

P 3 = S 3 ⋅ B 11 = A 21 ⋅ B 11 + A 22 ⋅ B 11 P_3 = S_3 \cdot B_{11} = A_{21} \cdot B_{11} + A_{22} \cdot B_{11} P3=S3B11=A21B11+A22B11

P 4 = A 22 ⋅ S 4 = A 22 ⋅ B 21 − A 22 ⋅ B 11 P_4 = A_{22} \cdot S_4 = A_{22} \cdot B_{21} - A_{22} \cdot B_{11} P4=A22S4=A22B21A22B11

P 5 = S 5 ⋅ S 6 = A 11 ⋅ B 11 + A 11 ⋅ B 22 + A 22 ⋅ B 11 + A 22 ⋅ B 22 P_5 = S_5 \cdot S_6 = A_{11} \cdot B_{11} + A_{11} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22} P5=S5S6=A11B11+A11B22+A22B11+A22B22

P 6 = S 7 ⋅ S 8 = A 12 ⋅ B 21 + A 12 ⋅ B 22 − A 22 ⋅ B 21 − A 22 ⋅ B 22 P_6 = S_7 \cdot S_8 = A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22} P6=S7S8=A12B21+A12B22A22B21A22B22

P 7 = S 9 ⋅ S 1 0 = A 11 ⋅ B 11 + A 11 ⋅ B 12 − A 21 ⋅ B 11 − A 21 ⋅ B 12 P_7 = S_9 \cdot S_10 = A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12} P7=S9S10=A11B11+A11B12A21B11A21B12

步骤4对步骤3创建的 P i P_i Pi矩阵进行加减法运算,计算出 C C C的4个 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n的子矩阵。

C 11 = P 5 + P 4 − P 2 + P 6 = A 11 ⋅ B 11 + A 12 ⋅ B 21 C_{11} = P_5 + P_4 - P_2 + P_6 = A_{11} \cdot B_{11} + A_{12} \cdot B_{21} C11=P5+P4P2+P6=A11B11+A12B21

C 12 = P 1 + P 2 = A 11 ⋅ B 12 + A 12 ⋅ B 22 C_{12} = P_1 + P_2 = A_{11} \cdot B_{12} + A_{12} \cdot B_{22} C12=P1+P2=A11B12+A12B22

C 21 = P 3 + P 4 = A 21 ⋅ B 11 + A 22 ⋅ B 21 C_{21} = P_3 + P_4 = A_{21} \cdot B_{11} + A_{22} \cdot B_{21} C21=P3+P4=A21B11+A22B21

C 22 = P 5 + P 1 − P 3 − P 7 = A 22 ⋅ B 22 + A 21 ⋅ B 12 C_{22} = P_5 + P_1 - P_3 - P_7 = A_{22} \cdot B_{22} + A_{21} \cdot B_{12} C22=P5+P1P3P7=A22B22+A21B12

如此,我们便获得矩阵 A A A B B B的乘积——矩阵 C C C

算法分析

之前说过,Strassen算法的时间复杂度是优于朴素计算的,可是,它到底是多少呢?我们不妨再回到Strassen算法的流程。当 n > 1 n > 1 n>1时,步骤1、2和4共花费 θ ( n 2 ) \theta(n^2) θ(n2)时间,步骤3要求7次 n 2 ∗ n 2 \frac{n}{2} * \frac{n}{2} 2n2n矩阵的乘法。因此,我们得到如下描述Strassen算法运行时间 T ( n ) T(n) T(n)的递归式:

T ( n ) = { θ ( 1 ) 若 n = 1 7 T ( n / 2 ) + θ ( n 2 ) 若 n > 1 T(n)=\left\{ \begin{aligned} & \theta(1) & 若n = 1\\ & 7T(n/2) + \theta(n^2) & 若n > 1\\ \end{aligned} \right. T(n)={θ(1)7T(n/2)+θ(n2)n=1n>1

求解上式可得, T ( n ) = θ ( n lg ⁡ 7 ) T(n) = \theta(n^{\lg7}) T(n)=θ(nlg7)

算法实现

废话千句,不如代码两行,接下来直接上Strassen算法的实现。(注意,如果 n n n不是2的幂,可以采取对原矩阵填充0的方式,使 n n n扩展到2的幂)。

Strassen算法

算法总结

Strassen算法发表于1969年,它的发表引起了很大的轰动。在此之前,很少人敢设想一个算法能渐近快于平凡算法SQUARE-MATRIX-MULTIPLY。矩阵乘法的上界自此被改进了。到目前为止, n ∗ n n*n nn矩阵相乘的渐近复杂性最优的算法是Coppersmith和Winograd提出的,运行时间是 O ( n 2.376 ) O(n^{2.376}) O(n2.376)

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值