Strassen的思想是把矩阵分块,用分块矩阵的和来生成七个中间矩阵后,再做乘法,因此复杂度为n的lg7次方。注意它在计算非整数时,会有累积误差。目前最优的是Coppersmith和Winograd提出的算法,复杂度是n的2.376次方。
python代码
import numpy as np
def StrassenSquareMatrixMultiple(A,B):
n = A.shape[0]
if n == 1:
return np.dot(A,B)
S1 = B[0:n//2,n//2:n] - B[n//2:n,n//2:n]
S2 = A[0:n//2,0:n//2] + A[0:n//2,n//2:n]
S3 = A[n//2:n,0:n//2] + A[n//2:n,n//2:n]
S4 = B[n//2:n,0:n//2] - B[0:n//2,0:n//2]
S5 = A[0:n//2,0:n//2] + A[n//2:n,n//2:n]
S6 = B[0:n//2,0:n//2] + B[n//2:n,n//2:n]
S7 = A[0:n//2,n//2:n] - A[n//2:n,n//2:n]
S8 = B[n//2:n,0:n//2] + B[n//2:n,n//2:n]
S9 = A[0:n//2,0:n//2] - A[n//2:n,0:n//2]
S10 = B[0:n//2,0:n//2] + B[0:n//2,n//2:n]
P1 = StrassenSquareMatrixMultiple(A[0:n//2,0:n//2],S1)
P2 = StrassenSquareMatrixMultiple(S2,B[n//2:n,n//2:n])
P3 = StrassenSquareMatrixMultiple(S3,B[0:n//2,0:n//2])
P4 = StrassenSquareMatrixMultiple(A[n//2:n,n//2:n],S4)
P5 = StrassenSquareMatrixMultiple(S5,S6)
P6 = StrassenSquareMatrixMultiple(S7,S8)
P7 = StrassenSquareMatrixMultiple(S9,S10)
C = np.zeros([n,n])
C[0:n//2,0:n//2] = P5+P4-P2+P6
C[0:n//2,n//2:n] = P1 + P2
C[n//2:n,0:n//2] = P3+P4
C[n//2:n,n//2:n] = P5+P1-P3-P7
return C
A = np.array([[1,2],
[3,4]])
B = np.array([[5,6],
[7,8]])
print(StrassenSquareMatrixMultiple(A,B))