Strassen's Algorithm for Matrix Multiplication
Algorithms for computing the product of two matrices
The naive algorithm
Time complexity: $\Theta(n^3)$
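For reference, here is a minimal sketch of the naive algorithm in Python. It uses three nested loops, each running n times, which gives the $n^3$ scalar multiplications behind the $\Theta(n^3)$ bound. The function name `naive_matrix_multiply` is illustrative, and the inputs are assumed to be plain lists of lists.

```python
def naive_matrix_multiply(A, B):
    """Multiply two n x n matrices given as lists of lists (illustrative sketch)."""
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for i in range(n):          # row of C
        for j in range(n):      # column of C
            total = 0
            for k in range(n):  # n multiplications per entry, n^3 in total
                total += A[i][k] * B[k][j]
            C[i][j] = total
    return C


if __name__ == "__main__":
    A = [[1, 2], [3, 4]]
    B = [[5, 6], [7, 8]]
    print(naive_matrix_multiply(A, B))  # [[19, 22], [43, 50]]
```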
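A simple divide-and-conquer algorithm
Assume both matrices are square (n × n) and that n is a power of 2.
Recursively partition each matrix into four square submatrices of size n/2 × n/2.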
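Write A, B and their product C = AB in block form:
$$A=\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix},\quad B=\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix},\quad C=\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}$$
Then the blocks of C are:
$$\begin{aligned}C_{11}&=A_{11}B_{11}+A_{12}B_{21}\\C_{12}&=A_{11}B_{12}+A_{12}B_{22}\\C_{21}&=A_{21}B_{11}+A_{22}B_{21}\\C_{22}&=A_{21}B_{12}+A_{22}B_{22}\end{aligned}$$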
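The block identities can be checked numerically. This is a minimal sketch that assumes NumPy is available. The helper name `blocks` and the random 4 × 4 integer test matrices are illustrative only.

```python
import numpy as np


def blocks(M):
    """Split a square matrix of even size into its four (n/2) x (n/2) blocks."""
    h = M.shape[0] // 2
    return M[:h, :h], M[:h, h:], M[h:, :h], M[h:, h:]


rng = np.random.default_rng(0)
A = rng.integers(0, 10, size=(4, 4))
B = rng.integers(0, 10, size=(4, 4))
A11, A12, A21, A22 = blocks(A)
B11, B12, B21, B22 = blocks(B)

# Assemble C from the four block formulas
C = np.block([
    [A11 @ B11 + A12 @ B21, A11 @ B12 + A12 @ B22],
    [A21 @ B11 + A22 @ B21, A21 @ B12 + A22 @ B22],
])
print(np.array_equal(C, A @ B))  # True
```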
The direct recursive algorithm, implemented in Python for square matrices:
import numpy as np

def multi_matrix(A, B):
    """Multiply two n x n NumPy arrays (n a power of 2) by simple divide and conquer."""
    n = A.shape[0]
    if n == 1:
        # Base case: return a 1 x 1 array rather than a bare scalar,
        # so the result always has the same shape as the inputs
        return A * B
    i = n // 2
    # Partition A and B into four (n/2) x (n/2) blocks
    a11 = A[:i, :i]
    a12 = A[:i, i:]
    a21 = A[i:, :i]
    a22 = A[i:, i:]
    b11 = B[:i, :i]
    b12 = B[:i, i:]
    b21 = B[i:, :i]
    b22 = B[i:, i:]
    # 8 recursive multiplications and 4 matrix additions
    c11 = multi_matrix(a11, b11) + multi_matrix(a12, b21)
    c12 = multi_matrix(a11, b12) + multi_matrix(a12, b22)
    c21 = multi_matrix(a21, b11) + multi_matrix(a22, b21)
    c22 = multi_matrix(a21, b12) + multi_matrix(a22, b22)
    # Reassemble the four blocks into C
    c1 = np.hstack((c11, c12))
    c2 = np.hstack((c21, c22))
    return np.vstack((c1, c2))

if __name__ == "__main__":
    A = np.random.randint(0, 20, size=[4, 4])
    B = np.random.randint(0, 20, size=[4, 4])
    C = multi_matrix(A, B)
    print(A, '\n', B, '\n', C)
    # Check the result against NumPy's built-in matrix product
    print(np.array_equal(C, A @ B))
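Time complexity analysis:
- For n = 1: $T(1) = \Theta(1)$.
- The four lines computing $C_{11}$, $C_{12}$, $C_{21}$ and $C_{22}$ make 8 recursive calls on n/2 × n/2 matrices in total, contributing $8T(n/2)$.
- The same four lines perform 4 additions of n/2 × n/2 matrices. Each addition adds $n^2/4$ pairs of entries, so together they cost $\Theta(n^2)$.
- The recurrence is therefore:
$$T(n)=\begin{cases}\Theta(1), & n=1\\ 8T(n/2)+\Theta(n^2), & n>1\end{cases}$$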
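Solving this recurrence shows that the running time is still $\Theta(n^3)$. This divide-and-conquer approach is therefore no faster than the naive algorithm.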
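One way to solve the recurrence is case 1 of the master theorem. For $T(n) = aT(n/b) + f(n)$, if $f(n) = O(n^{\log_b a - \epsilon})$ for some $\epsilon > 0$, then $T(n) = \Theta(n^{\log_b a})$. The calculation below sketches that step for the recurrence above:

$$a = 8,\quad b = 2,\quad n^{\log_b a} = n^{\log_2 8} = n^3,\qquad f(n) = \Theta(n^2) = O(n^{3-\epsilon})\ \text{with}\ \epsilon = 1 \;\Longrightarrow\; T(n) = \Theta(n^3)$$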
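Strassen's method
The idea is to perform only 7 recursive multiplications of n/2 × n/2 matrices at each step instead of 8. The price is a constant number of extra matrix additions and subtractions.
Steps:
1. Partition the input matrices into n/2 × n/2 blocks as above: $\Theta(1)$, using index calculations.
2. Create 10 matrices $S_1, \dots, S_{10}$. Each is n/2 × n/2 and is the sum or difference of two blocks from step 1: $\Theta(n^2)$.
3. Using the blocks from step 1 and the matrices from step 2, recursively compute the 7 matrix products $P_1, \dots, P_7$.
4. Obtain the blocks of the result C by adding and subtracting different combinations of the $P_i$: $\Theta(n^2)$.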
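The recurrence is:
$$T(n)=\begin{cases}\Theta(1), & n=1\\ 7T(n/2)+\Theta(n^2), & n>1\end{cases}$$
Solving it gives a running time of $\Theta(n^{\lg 7})$. Since $\lg 7 \approx 2.807$, this is asymptotically better than $\Theta(n^3)$.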
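The exponent $\lg 7$ becomes visible if we count only scalar multiplications. The simple recursion makes 8 calls per level and Strassen's method makes 7, so for $n = 2^k$ the counts are $8^k = n^3$ and $7^k = n^{\lg 7}$. This is a minimal sketch: the function name `count_mults` is hypothetical, and additions are deliberately ignored.

```python
import math


def count_mults(n, calls):
    """Scalar multiplications made by a recursion with `calls` subproblems of size n/2.

    Assumes n is a power of 2 and one multiplication at n == 1; additions are not counted.
    """
    if n == 1:
        return 1
    return calls * count_mults(n // 2, calls)


for n in (2, 8, 64, 1024):
    simple = count_mults(n, 8)  # equals n**3
    fast = count_mults(n, 7)    # equals 7**log2(n) = n**log2(7)
    print(n, simple, fast, round(fast / simple, 3))

print(math.log2(7))  # about 2.807
```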
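$S_1, \dots, S_{10}$:
$$\begin{aligned}S_1&=B_{12}-B_{22} & S_6&=B_{11}+B_{22}\\S_2&=A_{11}+A_{12} & S_7&=A_{12}-A_{22}\\S_3&=A_{21}+A_{22} & S_8&=B_{21}+B_{22}\\S_4&=B_{21}-B_{11} & S_9&=A_{11}-A_{21}\\S_5&=A_{11}+A_{22} & S_{10}&=B_{11}+B_{12}\end{aligned}$$
$P_1, \dots, P_7$, each expanded to show which block products it contains:
$$\begin{aligned}P_1&=A_{11}S_1=A_{11}B_{12}-A_{11}B_{22}\\P_2&=S_2B_{22}=A_{11}B_{22}+A_{12}B_{22}\\P_3&=S_3B_{11}=A_{21}B_{11}+A_{22}B_{11}\\P_4&=A_{22}S_4=A_{22}B_{21}-A_{22}B_{11}\\P_5&=S_5S_6=A_{11}B_{11}+A_{11}B_{22}+A_{22}B_{11}+A_{22}B_{22}\\P_6&=S_7S_8=A_{12}B_{21}+A_{12}B_{22}-A_{22}B_{21}-A_{22}B_{22}\\P_7&=S_9S_{10}=A_{11}B_{11}+A_{11}B_{12}-A_{21}B_{11}-A_{21}B_{12}\end{aligned}$$
The blocks of C are then: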
C11 = P5+P4-P2+P6
C12 = P1+P2
C21 = P3+P4
C22 = P5+P1-P3-P7
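These formulas can be checked without recursion. The sketch below builds the S and P matrices once with NumPy's `@`, using 2 × 2 blocks of random 4 × 4 integer matrices, and compares the assembled C with `A @ B`. It assumes NumPy, and the matrices are random test data. Because the blocks are themselves matrices, the check also exercises the fact that the identities never rely on multiplication being commutative.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4  # size of A and B (illustrative)
h = n // 2
A = rng.integers(-9, 10, size=(n, n))
B = rng.integers(-9, 10, size=(n, n))
A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]

# The ten sums and differences S1..S10
S1, S2, S3, S4, S5 = B12 - B22, A11 + A12, A21 + A22, B21 - B11, A11 + A22
S6, S7, S8, S9, S10 = B11 + B22, A12 - A22, B21 + B22, A11 - A21, B11 + B12

# The seven block products P1..P7 (computed directly with @, not recursively)
P1 = A11 @ S1
P2 = S2 @ B22
P3 = S3 @ B11
P4 = A22 @ S4
P5 = S5 @ S6
P6 = S7 @ S8
P7 = S9 @ S10

# Combine the P's into the four blocks of C
C = np.block([
    [P5 + P4 - P2 + P6, P1 + P2],
    [P3 + P4, P5 + P1 - P3 - P7],
])
print(np.array_equal(C, A @ B))  # True
```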
Python implementation:
import numpy as np

def strassen(A, B):
    """Multiply two n x n NumPy arrays (n a power of 2) with Strassen's method."""
    n = A.shape[0]
    if n == 1:
        # Base case: return a 1 x 1 array so shapes stay consistent
        return A * B
    i = n // 2
    # Step 1: partition A and B into four (n/2) x (n/2) blocks (slices are views)
    a11 = A[:i, :i]
    a12 = A[:i, i:]
    a21 = A[i:, :i]
    a22 = A[i:, i:]
    b11 = B[:i, :i]
    b12 = B[:i, i:]
    b21 = B[i:, :i]
    b22 = B[i:, i:]
    # Step 2: the ten sums and differences S1..S10
    s1 = b12 - b22
    s2 = a11 + a12
    s3 = a21 + a22
    s4 = b21 - b11
    s5 = a11 + a22
    s6 = b11 + b22
    s7 = a12 - a22
    s8 = b21 + b22
    s9 = a11 - a21
    s10 = b11 + b12
    # Step 3: seven recursive products P1..P7
    p1 = strassen(a11, s1)
    p2 = strassen(s2, b22)
    p3 = strassen(s3, b11)
    p4 = strassen(a22, s4)
    p5 = strassen(s5, s6)
    p6 = strassen(s7, s8)
    p7 = strassen(s9, s10)
    # Step 4: combine the P's into the four blocks of C
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p5 + p1 - p3 - p7
    c1 = np.hstack((c11, c12))
    c2 = np.hstack((c21, c22))
    return np.vstack((c1, c2))
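The implementation above handles only sizes that are a power of 2, and it recurses all the way down to 1 × 1 blocks. A common practical variant does two things differently. It pads the inputs with zeros up to the next power of 2, which is safe because the product of the padded matrices contains AB in its top-left corner. It also falls back to ordinary multiplication (NumPy's `@`) once blocks are no larger than a cutoff size. This variant is sketched here as an assumption, not as part of the method described above. The names `strassen_padded` and `_strassen` and the `cutoff` parameter are hypothetical.

```python
import numpy as np


def strassen_padded(A, B, cutoff=64):
    """Strassen multiplication for square matrices of any size (sketch).

    Pads to the next power of 2 with zeros; `cutoff` is a hypothetical
    tuning value below which ordinary multiplication is used.
    """
    n = A.shape[0]
    m = 1
    while m < n:
        m *= 2
    dtype = np.result_type(A, B)
    Ap = np.zeros((m, m), dtype=dtype)
    Bp = np.zeros((m, m), dtype=dtype)
    Ap[:n, :n] = A
    Bp[:n, :n] = B
    return _strassen(Ap, Bp, cutoff)[:n, :n]


def _strassen(A, B, cutoff):
    n = A.shape[0]
    if n <= cutoff or n == 1:
        return A @ B  # small blocks: ordinary multiplication
    h = n // 2
    a11, a12, a21, a22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    b11, b12, b21, b22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    p1 = _strassen(a11, b12 - b22, cutoff)
    p2 = _strassen(a11 + a12, b22, cutoff)
    p3 = _strassen(a21 + a22, b11, cutoff)
    p4 = _strassen(a22, b21 - b11, cutoff)
    p5 = _strassen(a11 + a22, b11 + b22, cutoff)
    p6 = _strassen(a12 - a22, b21 + b22, cutoff)
    p7 = _strassen(a11 - a21, b11 + b12, cutoff)
    return np.block([
        [p5 + p4 - p2 + p6, p1 + p2],
        [p3 + p4, p5 + p1 - p3 - p7],
    ])


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    A = rng.integers(0, 20, size=(100, 100))
    B = rng.integers(0, 20, size=(100, 100))
    print(np.array_equal(strassen_padded(A, B, cutoff=8), A @ B))  # True
```

The cutoff exists because, for small blocks, the extra additions and the Python call overhead cost more than the multiplication they save. A suitable value depends on the machine and would have to be measured.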