[Divide and Conquer] [Python Implementation] Strassen Matrix Multiplication


Author: 丷从心·

Series: Divide-and-Conquer Algorithms



Problem Description

  • Let $A$ and $B$ be two $n \times n$ matrices. Each entry of their product $C = AB$ is $c_{ij} = \sum_{k = 1}^{n} a_{ik} b_{kj}$
  • Computing one entry $c_{ij}$ of $C$ takes $n$ multiplications and $n - 1$ additions, so computing all $n^{2}$ entries of $C$ takes $O(n^{3})$ time
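The definition above translates directly into a triple loop. The following is a minimal sketch in plain Python, assuming square matrices stored as lists of lists; the function name `naive_matrix_multiply` is made up for this illustration. It is meant to make the $n^{3}$ multiplication count visible, not to be fast.

```python
def naive_matrix_multiply(a, b):
    """Multiply two n x n matrices given as lists of lists (illustrative helper)."""
    n = len(a)
    c = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # c[i][j] is the dot product of row i of a and column j of b:
            # n multiplications and n - 1 additions.
            total = a[i][0] * b[0][j]
            for k in range(1, n):
                total += a[i][k] * b[k][j]
            c[i][j] = total
    return c


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(naive_matrix_multiply(a, b))  # [[19, 22], [43, 50]]
```

The three nested loops each run $n$ times, which is where the $O(n^{3})$ bound comes from.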

Basic Algorithm

  • Assume $n$ is a power of $2$. Partition each of the matrices $A$, $B$, and $C$ into $4$ submatrices of equal size, each an $n / 2 \times n / 2$ square matrix

$$\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix} = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}$$

$$\begin{aligned} C_{11} &= A_{11} B_{11} + A_{12} B_{21} \\ C_{12} &= A_{11} B_{12} + A_{12} B_{22} \\ C_{21} &= A_{21} B_{11} + A_{22} B_{21} \\ C_{22} &= A_{21} B_{12} + A_{22} B_{22} \end{aligned}$$
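The block formulas can be checked numerically. The sketch below assumes NumPy and a small $4 \times 4$ integer example; the seed and the variable names are arbitrary choices for illustration. It builds $C$ from the eight half-size products and compares it with the ordinary product.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily
n = 4
a = rng.integers(0, 10, (n, n))
b = rng.integers(0, 10, (n, n))

# Split each matrix into four n/2 x n/2 blocks.
mid = n // 2
a11, a12, a21, a22 = a[:mid, :mid], a[:mid, mid:], a[mid:, :mid], a[mid:, mid:]
b11, b12, b21, b22 = b[:mid, :mid], b[:mid, mid:], b[mid:, :mid], b[mid:, mid:]

# Eight half-size products and four half-size additions.
c11 = a11 @ b11 + a12 @ b21
c12 = a11 @ b12 + a12 @ b22
c21 = a21 @ b11 + a22 @ b21
c22 = a21 @ b12 + a22 @ b22

# Reassemble the blocks and compare with the direct product.
c = np.block([[c11, c12], [c21, c22]])
blocks_match = np.array_equal(c, a @ b)
print(blocks_match)  # True
```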

Time complexity
  • Multiplying $2$ square matrices of order $n$ reduces to $8$ products and $4$ additions of square matrices of order $n / 2$. Adding $2$ matrices of size $n / 2 \times n / 2$ clearly takes $O(n^{2})$ time

$$T(n) = \begin{cases} O(1) & n = 2 \\ 8 T(n / 2) + O(n^{2}) & n > 2 \end{cases}$$

Solving this recurrence gives a bound no better than the direct method:

$$T(n) = O(n^{3})$$
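One way to see this bound is the master theorem. The derivation below is a sketch; the symbols $a$, $b$, and $d$ are the usual master-theorem parameters and are not notation used elsewhere in this post.

$$T(n) = a\,T(n / b) + O(n^{d}), \qquad a = 8,\; b = 2,\; d = 2$$

$$\log_{b} a = \log_{2} 8 = 3 > d = 2 \;\Longrightarrow\; T(n) = O(n^{\log_{b} a}) = O(n^{3})$$

Because the exponent is driven by $\log_{2} a$, the only way to beat $O(n^{3})$ with this scheme is to reduce the number of half-size products below $8$.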


Strassen's Algorithm

  • Strassen's algorithm uses only $7$ multiplications, at the cost of more additions and subtractions

$$\begin{aligned} M_{1} &= A_{11} (B_{12} - B_{22}) \\ M_{2} &= (A_{11} + A_{12}) B_{22} \\ M_{3} &= (A_{21} + A_{22}) B_{11} \\ M_{4} &= A_{22} (B_{21} - B_{11}) \\ M_{5} &= (A_{11} + A_{22})(B_{11} + B_{22}) \\ M_{6} &= (A_{12} - A_{22})(B_{21} + B_{22}) \\ M_{7} &= (A_{11} - A_{21})(B_{11} + B_{12}) \end{aligned}$$

$$\begin{aligned} C_{11} &= M_{5} + M_{4} - M_{2} + M_{6} \\ C_{12} &= M_{1} + M_{2} \\ C_{21} &= M_{3} + M_{4} \\ C_{22} &= M_{5} + M_{1} - M_{3} - M_{7} \end{aligned}$$
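The seven products and four combinations can be sanity-checked on $2 \times 2$ matrices of plain integers. This is a minimal sketch; `strassen_2x2` and `direct_2x2` are hypothetical helper names, and the random test range is arbitrary.

```python
import random


def strassen_2x2(a, b):
    """Multiply two 2 x 2 matrices with exactly 7 scalar multiplications."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b

    m1 = a11 * (b12 - b22)
    m2 = (a11 + a12) * b22
    m3 = (a21 + a22) * b11
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a22) * (b11 + b22)
    m6 = (a12 - a22) * (b21 + b22)
    m7 = (a11 - a21) * (b11 + b12)

    return [
        [m5 + m4 - m2 + m6, m1 + m2],
        [m3 + m4, m5 + m1 - m3 - m7],
    ]


def direct_2x2(a, b):
    """Reference result from the definition (8 scalar multiplications)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
        for i in range(2)
    ]


random.seed(0)  # arbitrary seed so the check is repeatable
all_match = True
for trial in range(1000):
    a = [[random.randint(-9, 9) for _ in range(2)] for _ in range(2)]
    b = [[random.randint(-9, 9) for _ in range(2)] for _ in range(2)]
    all_match = all_match and strassen_2x2(a, b) == direct_2x2(a, b)

print(all_match)  # True
```

In every product the factor built from $A$ stays on the left and the factor built from $B$ on the right, so the identities do not rely on commutativity. That is why the same formulas remain valid when the entries are matrix blocks rather than numbers.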

Time complexity
  • Strassen's algorithm makes $7$ recursive calls to multiply matrices of order $n / 2$ and performs $18$ additions and subtractions of matrices of order $n / 2$

$$T(n) = \begin{cases} O(1) & n = 2 \\ 7 T(n / 2) + O(n^{2}) & n > 2 \end{cases}$$

$$T(n) = O(n^{\log_{2} 7}) \approx O(n^{2.81})$$
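To get a feel for the gap between $8$ and $7$ recursive products, the sketch below counts scalar multiplications only. It assumes the recursion runs all the way down to $1 \times 1$ blocks (the recurrence above stops at $n = 2$), and `scalar_multiplications` is a made-up helper for illustration.

```python
import math


def scalar_multiplications(n, branches):
    """Scalar multiplications used when recursing down to 1 x 1 blocks.

    `branches` is the number of half-size products per level:
    8 for the basic block algorithm, 7 for Strassen.
    """
    if n == 1:
        return 1
    return branches * scalar_multiplications(n // 2, branches)


print(math.log2(7))  # about 2.807
for n in (2, 64, 1024):
    print(n, scalar_multiplications(n, 8), scalar_multiplications(n, 7))
```

For $n = 2^{k}$ the counts are $8^{k} = n^{3}$ and $7^{k} = n^{\log_{2} 7}$, so the saving grows with $n$. The extra additions are not counted here; they only contribute the $O(n^{2})$ term of the recurrence.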


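Complexity of the Problem

  • Hopcroft and Kerr proved that $7$ multiplications are necessary to compute the product of $2$ matrices of size $2 \times 2$
  • An upper bound of $O(n^{2.376})$ was achieved by the Coppersmith-Winograd algorithm, and later refinements have lowered the exponent only slightly, to roughly $2.37$. The best known lower bound for matrix multiplication is still the trivial bound $\Omega(n^{2})$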

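Python Implementation

import numpy as np


def strassen_matrix_multiply(a, b):
    """Multiply two n x n NumPy arrays with Strassen's algorithm.

    Assumes n is a power of 2, as in the analysis above, so that every
    split above the threshold yields equal halves.
    """
    n = a.shape[0]

    # At or below the threshold, fall back to ordinary matrix multiplication
    if n <= 128:
        return np.dot(a, b)

    # Split each input matrix into four submatrices
    mid = n // 2

    a11 = a[:mid, :mid]
    a12 = a[:mid, mid:]
    a21 = a[mid:, :mid]
    a22 = a[mid:, mid:]

    b11 = b[:mid, :mid]
    b12 = b[:mid, mid:]
    b21 = b[mid:, :mid]
    b22 = b[mid:, mid:]

    # Recursively compute the seven matrix products
    m1 = strassen_matrix_multiply(a11, b12 - b22)
    m2 = strassen_matrix_multiply(a11 + a12, b22)
    m3 = strassen_matrix_multiply(a21 + a22, b11)
    m4 = strassen_matrix_multiply(a22, b21 - b11)
    m5 = strassen_matrix_multiply(a11 + a22, b11 + b22)
    m6 = strassen_matrix_multiply(a12 - a22, b21 + b22)
    m7 = strassen_matrix_multiply(a11 - a21, b11 + b12)

    # Compute the four submatrices of the result
    c11 = m5 + m4 - m2 + m6
    c12 = m1 + m2
    c21 = m3 + m4
    c22 = m5 + m1 - m3 - m7

    # Assemble the four submatrices into the result matrix.
    # np.zeros defaults to float64, so the result is a float array
    # even when the inputs are integer arrays.
    c = np.zeros((n, n))

    c[:mid, :mid] = c11
    c[:mid, mid:] = c12
    c[mid:, :mid] = c21
    c[mid:, mid:] = c22

    return c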


a = np.random.randint(0, 10, (256, 256))
b = np.random.randint(0, 10, (256, 256))

res = strassen_matrix_multiply(a, b)

print('Matrix a:')
print(a)

print('\nMatrix b:')
print(b)

print('\nProduct matrix c:')
print(res)

Sample output (the inputs are random, so the values differ from run to run):

Matrix a:
[[4 8 5 ... 6 3 7]
 [8 0 0 ... 8 4 6]
 [5 4 2 ... 7 4 8]
 ...
 [6 0 7 ... 5 5 6]
 [9 3 6 ... 6 2 9]
 [2 2 5 ... 7 3 6]]

Matrix b:
[[8 6 5 ... 3 7 2]
 [2 9 5 ... 4 2 3]
 [1 0 0 ... 0 5 4]
 ...
 [0 6 4 ... 3 9 1]
 [9 9 7 ... 5 4 1]
 [6 0 1 ... 2 6 1]]

Product matrix c:
[[5306. 5218. 5339. ... 5510. 5653. 5120.]
 [5327. 5038. 5228. ... 5132. 5251. 4969.]
 [5325. 5265. 5237. ... 5252. 5474. 4962.]
 ...
 [5491. 5176. 5190. ... 5443. 5463. 5040.]
 [5648. 5509. 5191. ... 5439. 5707. 5273.]
 [5034. 5027. 5061. ... 4866. 5366. 4772.]]
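The run above prints the product but does not check it, and the implementation relies on $n$ being a power of $2$. The sketch below is one possible way to address both points, not part of the original implementation: it zero-pads the inputs to the next power of $2$, multiplies, crops the result, and compares it with NumPy's own product. The names `strassen` and `strassen_padded`, the `threshold` parameter, and the padding strategy are all assumptions made for this example.

```python
import numpy as np


def strassen(a, b, threshold=64):
    """Strassen product of two n x n arrays; n must halve evenly above threshold."""
    n = a.shape[0]
    if n <= threshold:
        return a @ b
    h = n // 2
    a11, a12, a21, a22 = a[:h, :h], a[:h, h:], a[h:, :h], a[h:, h:]
    b11, b12, b21, b22 = b[:h, :h], b[:h, h:], b[h:, :h], b[h:, h:]

    # The seven recursive products.
    m1 = strassen(a11, b12 - b22, threshold)
    m2 = strassen(a11 + a12, b22, threshold)
    m3 = strassen(a21 + a22, b11, threshold)
    m4 = strassen(a22, b21 - b11, threshold)
    m5 = strassen(a11 + a22, b11 + b22, threshold)
    m6 = strassen(a12 - a22, b21 + b22, threshold)
    m7 = strassen(a11 - a21, b11 + b12, threshold)

    # np.block keeps the dtype of the blocks, so integer inputs stay integer.
    return np.block([
        [m5 + m4 - m2 + m6, m1 + m2],
        [m3 + m4, m5 + m1 - m3 - m7],
    ])


def strassen_padded(a, b, threshold=64):
    """Zero-pad to the next power of 2, multiply, then crop (assumed strategy)."""
    n = a.shape[0]
    size = 1 << (n - 1).bit_length()  # smallest power of 2 that is >= n
    pa = np.zeros((size, size), dtype=a.dtype)
    pb = np.zeros((size, size), dtype=b.dtype)
    pa[:n, :n] = a
    pb[:n, :n] = b
    return strassen(pa, pb, threshold)[:n, :n]


rng = np.random.default_rng(0)  # arbitrary seed
a = rng.integers(0, 10, (100, 100))
b = rng.integers(0, 10, (100, 100))
result = strassen_padded(a, b, threshold=16)
print(np.array_equal(result, a @ b))  # True
```

The padded rows and columns are zero, so they contribute nothing to the top-left $n \times n$ block of the product. With integer inputs the comparison is exact; with floating-point inputs `np.allclose` would be the more appropriate check, because Strassen's extra additions and subtractions change the rounding.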
