施特拉森算法（英语：Strassen algorithm）是一种计算矩阵乘法的算法，时间复杂度为 $O(n^{\log_2 7}) \approx O(n^{2.807})$。
简介
施特拉森算法由沃尔克·施特拉森（德语：Volker Strassen）于1969年提出，是第一个时间复杂度低于 $O(n^3)$ 的矩阵乘法算法。由于该算法易于理解，又是第一个被提出的此类算法，常被算法教材用作以主定理（英语：Master theorem）计算时间复杂度的例子。
另外，施特拉森算法证明了矩阵乘法存在时间复杂度低于 $O(n^3)$ 的算法，因而吸引了更多学者投入研究，寻找更快的算法。
评论
一般矩阵乘法的时间复杂度为 $O(n^3)$。施特拉森算法在每一层分治（英语：Divide and conquer algorithm）中只需进行七次子矩阵乘法（而非朴素分块法的八次），其递推式为 $T(n) = 7T(n/2) + \Theta(n^2)$，依照主定理（英语：Master theorem）可以得出时间复杂度为 $O(n^{\log_2 7}) \approx O(n^{2.807})$。但施特拉森算法的数值稳定性较差。
现时时间复杂度最低的矩阵乘法算法是Coppersmith–Winograd方法的一种扩展方法，其算法复杂度约为 $O(n^{2.37})$。
以上内容参考自维基百科。
代码实现
import itertools
import random
def generateMatrix(n, start, end):
    """Build an ``n x n`` matrix of random integers.

    Each entry is drawn with ``random.randint(start, end)`` (both bounds
    inclusive), filling the matrix row by row.

    Args:
        n: Number of rows and columns; a non-positive value yields ``[]``.
        start: Smallest value an entry may take.
        end: Largest value an entry may take.

    Returns:
        A new list of ``n`` row lists, each holding ``n`` integers.
    """
    return [[random.randint(start, end) for _col in range(n)] for _row in range(n)]
def departMatrix(a):
    """Split a square matrix into its four quadrants.

    The split index is ``h = ceil(n / 2)``, so for odd ``n`` the top-left
    quadrant is the larger one (a 3x3 matrix yields 2x2, 2x1, 1x2 and 1x1
    blocks). For ``n == 1`` the right-hand and bottom quadrants are ``[]``.

    Args:
        a: An ``n x n`` matrix given as a list of row lists.

    Returns:
        A tuple ``(a00, a01, a10, a11)`` holding the top-left, top-right,
        bottom-left and bottom-right quadrants, each a new list of new rows.

    Raises:
        ValueError: If ``a`` is not square. (Previously wide rows were
            silently truncated and short rows raised IndexError.)
    """
    n = len(a)
    if any(len(row) != n for row in a):
        raise ValueError("departMatrix expects a square matrix")
    # For an integer index i, `i < n / 2` is the same test as `i < (n + 1) // 2`.
    h = (n + 1) // 2
    top, bottom = a[:h], a[h:]
    a00 = [row[:h] for row in top]
    # When n == 1 there is no right half; return [] rather than [[]].
    a01 = [row[h:] for row in top] if n > h else []
    a10 = [row[:h] for row in bottom]
    a11 = [row[h:] for row in bottom]
    return a00, a01, a10, a11
def matrixAdd(a, b, d='+'):
    """Return the element-wise sum or difference of two matrices.

    Args:
        a: A matrix given as a list of row lists.
        b: A matrix with exactly the same shape as ``a``.
        d: ``'+'`` to compute ``a + b`` (default) or ``'-'`` to compute
            ``a - b``.

    Returns:
        A new matrix of the same shape; neither input is modified.

    Raises:
        ValueError: If ``d`` is neither ``'+'`` nor ``'-'`` (previously this
            silently returned ``[]``), or if the shapes of ``a`` and ``b``
            differ.
    """
    if d not in ('+', '-'):
        raise ValueError(f"matrixAdd: unsupported operator {d!r}, expected '+' or '-'")
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        raise ValueError("matrixAdd expects matrices of the same shape")
    # Iterate each row's actual width (zip) instead of len(a), so
    # rectangular matrices are handled correctly.
    if d == '+':
        return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]
    return [[x - y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]
def matrixMerge(a00, a01, a10, a11):
    """Reassemble four quadrants into a single matrix.

    Rows of ``a00`` and ``a01`` are concatenated side by side to form the
    top half, and rows of ``a10`` and ``a11`` form the bottom half. If one
    quadrant of a pair has fewer rows than the other, its missing rows are
    treated as empty. Quadrants therefore need not be equal-sized. This
    makes the function the inverse of ``departMatrix`` for every square
    size, including odd sizes and the 1x1 case, where the right-hand
    quadrants are ``[]``.

    Args:
        a00, a01: Top-left and top-right quadrants (lists of row lists).
        a10, a11: Bottom-left and bottom-right quadrants.

    Returns:
        A new matrix; the input quadrants are not modified.
    """
    # `left + right` builds a fresh list, so the shared [] fill value is never mutated.
    top = [left + right for left, right in itertools.zip_longest(a00, a01, fillvalue=[])]
    bottom = [left + right for left, right in itertools.zip_longest(a10, a11, fillvalue=[])]
    return top + bottom
def matrixNormalMultiply(a, b):
    """Multiply two matrices with the schoolbook triple-loop algorithm.

    Previously every dimension was taken from ``len(a)``, which gave wrong
    results for non-square operands. The inner dimension is now checked and
    rectangular operands are supported.

    Args:
        a: An ``m x k`` matrix given as a list of row lists.
        b: A ``k x p`` matrix given as a list of row lists.

    Returns:
        The ``m x p`` product as a new list of row lists.

    Raises:
        ValueError: If a row of ``a`` does not have ``len(b)`` entries.
    """
    if any(len(row) != len(b) for row in a):
        raise ValueError("matrixNormalMultiply: inner dimensions do not match")
    # Transpose b once so each output entry is a dot product of two sequences.
    columns = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in columns] for row in a]
def _padMatrix(m):
    """Return a copy of square matrix ``m`` with one zero row and one zero column appended."""
    size = len(m) + 1
    return [row + [0] for row in m] + [[0] * size]


def strassenMatrixMultiply(a, b):
    """Multiply two equal-size square matrices with Strassen's algorithm.

    Each recursion level splits both operands into four quadrants and forms
    the product from seven sub-products (m1..m7) instead of eight, giving
    O(n^log2(7)) ~ O(n^2.807) time. Matrices of size 2 or smaller are
    multiplied directly with ``matrixNormalMultiply``.

    Odd sizes are padded with one zero row and column so they split into
    four equal quadrants, and the result is trimmed back. Any ``n`` is
    therefore accepted; previously only powers of two worked, and 0x0 or
    1x1 input returned None.

    Args:
        a: An ``n x n`` matrix given as a list of row lists.
        b: An ``n x n`` matrix given as a list of row lists.

    Returns:
        The ``n x n`` product as a new list of row lists.

    Raises:
        ValueError: If ``a`` and ``b`` are not square matrices of the same size.
    """
    n = len(a)
    if len(b) != n or any(len(row) != n for row in a) or any(len(row) != n for row in b):
        raise ValueError("strassenMatrixMultiply expects two square matrices of the same size")
    if n <= 2:
        # Base case. This also covers 0x0 and 1x1 inputs, which used to fall
        # through both branches and return None.
        return matrixNormalMultiply(a, b)
    if n % 2:
        # The zero padding contributes nothing to the top-left n x n block
        # of the product, so trimming the padded result is exact.
        product = strassenMatrixMultiply(_padMatrix(a), _padMatrix(b))
        return [row[:n] for row in product[:n]]

    a00, a01, a10, a11 = departMatrix(a)
    b00, b01, b10, b11 = departMatrix(b)

    # Strassen's seven sub-products.
    m1 = strassenMatrixMultiply(matrixAdd(a00, a11), matrixAdd(b00, b11))
    m2 = strassenMatrixMultiply(matrixAdd(a10, a11), b00)
    m3 = strassenMatrixMultiply(a00, matrixAdd(b01, b11, '-'))
    m4 = strassenMatrixMultiply(a11, matrixAdd(b10, b00, '-'))
    m5 = strassenMatrixMultiply(matrixAdd(a00, a01), b11)
    m6 = strassenMatrixMultiply(matrixAdd(a10, a00, '-'), matrixAdd(b00, b01))
    m7 = strassenMatrixMultiply(matrixAdd(a01, a11, '-'), matrixAdd(b10, b11))

    # Combine the sub-products into the four quadrants of the result.
    c00 = matrixAdd(matrixAdd(matrixAdd(m1, m4), m5, '-'), m7)
    c01 = matrixAdd(m3, m5)
    c10 = matrixAdd(m2, m4)
    c11 = matrixAdd(matrixAdd(matrixAdd(m1, m3), m2, '-'), m6)
    return matrixMerge(c00, c01, c10, c11)
def matrixShow(a):
    """Print a matrix to stdout, one row per line, followed by a blank line.

    Every value is printed followed by a single space (including the last
    one in a row). The number of columns printed per row is taken from the
    first row.
    """
    columns = range(len(a[0])) if a else range(0)
    for row in a:
        for col in columns:
            print(row[col], end=' ')
        print("")
    print('')
if __name__ == '__main__':
    # Demo: multiply two random 4x4 matrices with entries in [10, 99] and
    # print both operands followed by their product.
    size, low, high = 4, 10, 99
    a = generateMatrix(size, low, high)
    b = generateMatrix(size, low, high)
    for operand in (a, b):
        matrixShow(operand)
    matrixShow(strassenMatrixMultiply(a, b))