第四章 分治策略(2)

矩阵乘法的Strassen算法

求解两个矩阵的点乘积的算法

简单算法

时间复杂度为:Θ(n3)

简单的分治算法

假设矩阵的行列数相等的方阵,并且为2的幂。
递归地将两个矩阵分别平均分解为行列数为(n/2,n/2)的四个方阵。
已知:
在这里插入图片描述
在这里插入图片描述

直接的递归算法伪代码:
在这里插入图片描述
python实现(方阵的乘积):

import sys
import numpy as np

# index = 0
def multi_matrix(A,B):
    n= A.shape[0]
    if n == 1:
        C = A[0][0]*B[0][0]
        return C
    else:
        i = int(n/2)
        a11 = np.array([[col for col in row[:i]] for row in A[:i]])
        a12 = np.array([[col for col in row[i:]] for row in A[:i]])
        a21 = np.array([[col for col in row[:i]] for row in A[i:]])
        a22 = np.array([[col for col in row[i:]] for row in A[i:]])
        b11 = np.array([[col for col in row[:i]] for row in B[:i]])
        b12 = np.array([[col for col in row[i:]] for row in B[:i]])
        b21 = np.array([[col for col in row[:i]] for row in B[i:]])
        b22 = np.array([[col for col in row[i:]] for row in B[i:]])

        c11 = multi_matrix(a11,b11)+multi_matrix(a12,b21)
        c12 = multi_matrix(a11,b12)+multi_matrix(a12,b22)
        c21 = multi_matrix(a21,b11)+multi_matrix(a22,b21)
        c22 = multi_matrix(a21,b12)+multi_matrix(a22,b22)

        c1 = np.hstack((c11,c12))
        c2 = np.hstack((c21,c22))
        C = np.vstack((c1,c2))

        return C


if __name__ == "__main__":
    # A = list(map(int, sys.stdin.readline().strip().split(' ')))
    A = np.random.randint(0, 20, size=[4, 4])
    B = np.random.randint(0, 20, size=[4, 4])
    C = multi_matrix(A,B)
    print(A,'\n',B,'\n',C)

时间复杂度分析:

  1. 对于n=1,T(1) = Θ(1)
  2. 6-9行,共8次调用,8T(n/2)
  3. 6-9行,执行加法的次数为4次,复杂度为Θ(n2)
  4. 所以递归式:
    在这里插入图片描述

求解后发现,算法的时间复杂度仍为Θ(n3),并没有优于简单算法

Strassen方法

思想是降低递归的次数,8→7
步骤:
1.分解矩阵,Θ(1)
2 创建10个(n/2,n/2)的矩阵S1–S10,Θ(n2)
3 用1和2中产生的矩阵,递归计算矩阵积P1–P7
4 通过P矩阵的不同组合进行加减运算,得到最终结果。Θ(n2)

递归式:
在这里插入图片描述

求得时间复杂度为:Θ(n(lg7))

S1-S10:
在这里插入图片描述

在这里插入图片描述

P1-P7:
在这里插入图片描述

C11 = P5+P4-P2+P6
C12 = P1+P2
C21 = P3+P4
C22 = P5+P1-P3-P7

python 实现:

def strassen(A,B):
    n = A.shape[0]
    if n == 1:
        C = A[0][0] * B[0][0]
        return C
    else:
        i = int(n / 2)
        a11 = np.array([[col for col in row[:i]] for row in A[:i]])
        a12 = np.array([[col for col in row[i:]] for row in A[:i]])
        a21 = np.array([[col for col in row[:i]] for row in A[i:]])
        a22 = np.array([[col for col in row[i:]] for row in A[i:]])
        b11 = np.array([[col for col in row[:i]] for row in B[:i]])
        b12 = np.array([[col for col in row[i:]] for row in B[:i]])
        b21 = np.array([[col for col in row[:i]] for row in B[i:]])
        b22 = np.array([[col for col in row[i:]] for row in B[i:]])

        s1 = b12-b22
        s2 = a11 + a12
        s3 = a21+a22
        s4 = b21-b11
        s5 = a11+a22
        s6 = b11+b22
        s7 = a12-a22
        s8 = b21+b22
        s9 = a11-a21
        s10 = b11+b12

        p1 = strassen(a11,s1)
        p2 = strassen(s2,b22)
        p3 = strassen(s3,b11)
        p4 = strassen(a22,s4)
        p5 = strassen(s5,s6)
        p6 = strassen(s7,s8)
        p7 = strassen(s9,s10)

        c11 = p5+p4-p2+p6
        c12 = p1+p2
        c21 = p3+p4
        c22 = p5+p1-p3-p7

        c1 = np.hstack((c11, c12))
        c2 = np.hstack((c21, c22))
        C = np.vstack((c1, c2))

        return C
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值