Strassen矩阵乘法

施特拉森算法(英语:Strassen algorithm)是一个计算矩阵乘法算法,时间复杂度为。 

简介

施特拉森算法在1969年由沃尔克·施特拉森(德语:Volker Strassen)提出来,是第一个时间复杂度低于矩阵乘法算法。由于算法简单理解,且为第一个被提出来的特性,常被算法教材拿来当作主定理(英语:Master theorem)计算时间复杂度的例子。

 

 

另外,因为施特拉森算法证明了矩阵乘法存在时间复杂度低于的算法,使得更多学者投入研究,寻找更快的算法。

 

评论

一般矩阵乘法的时间复杂度为,施特拉森算法因为只有每次的分治法(英语:Divide and conquer algorithm)只有七个矩阵乘法运算,所以依照主定理(英语:Master theorem)可以得出时间复杂度为{\displaystyle O(n^{\log _{2}7})=O(n^{2.807})}。但Strassen算法的数值稳定性较差。

现时时间复杂度最低的矩阵乘法算法是Coppersmith-Winograd方法的一种扩展方法,其算法复杂度为

 以上参考wiki

 代码实现

import random


def generateMatrix(n, start, end):
    matrix = []
    for i in range(n):
        raw = []
        for j in range(n):
            raw.append(random.randint(start, end))
        matrix.append(raw)
    return matrix


def departMatrix(a):
    a00 = []
    a01 = []
    a10 = []
    a11 = []
    n = len(a)
    for i in range(n):
        raw00 = []
        raw01 = []
        raw10 = []
        raw11 = []
        for j in range(n):
            if i < n / 2 and j < n / 2:
                raw00.append(a[i][j])
            elif i < n / 2 <= j:
                raw01.append(a[i][j])
            elif i >= n / 2 > j:
                raw10.append(a[i][j])
            else:
                raw11.append(a[i][j])
        if len(raw00) != 0:
            a00.append(raw00)
        if len(raw01) != 0:
            a01.append(raw01)
        if len(raw10) != 0:
            a10.append(raw10)
        if len(raw11) != 0:
            a11.append(raw11)
    return a00, a01, a10, a11


def matrixAdd(a, b, d='+'):
    c = []
    if d == '+':
        for i in range(len(a)):
            raw = []
            for j in range(len(a)):
                raw.append(a[i][j] + b[i][j])
            c.append(raw)
    elif d == '-':
        for i in range(len(a)):
            raw = []
            for j in range(len(a)):
                raw.append(a[i][j] - b[i][j])
            c.append(raw)
    return c


def matrixMerge(a00, a01, a10, a11):
    c = []
    n = len(a00)
    for i in range(n * 2):
        row = []
        if i < n:
            for j in range(n):
                row.append(a00[i][j])
            for j in range(n):
                row.append(a01[i][j])
        else:
            for j in range(n):
                row.append(a10[i - n][j])
            for j in range(n):
                row.append(a11[i - n][j])
        c.append(row)
    return c


def matrixNormalMultiply(a, b):
    n = len(a)
    c = []
    for k in range(n):
        raw = []
        for i in range(n):
            sum = 0
            for j in range(n):
                sum = sum + (a[k][j] * b[j][i])
            raw.append(sum)
        c.append(raw)
    return c


def strassenMatrixMultiply(a, b):
    if len(a) > 2 or len(b) > 2:
        a00, a01, a10, a11 = departMatrix(a)
        b00, b01, b10, b11 = departMatrix(b)
        m1 = strassenMatrixMultiply(matrixAdd(a00, a11), matrixAdd(b00, b11))
        m2 = strassenMatrixMultiply(matrixAdd(a10, a11), b00)
        m3 = strassenMatrixMultiply(a00, matrixAdd(b01, b11, '-'))
        m4 = strassenMatrixMultiply(a11, matrixAdd(b10, b00, '-'))
        m5 = strassenMatrixMultiply(matrixAdd(a00, a01), b11)
        m6 = strassenMatrixMultiply(matrixAdd(a10, a00, '-'), matrixAdd(b00, b01))
        m7 = strassenMatrixMultiply(matrixAdd(a01, a11, '-'), matrixAdd(b10, b11))
        c00 = matrixAdd(matrixAdd(matrixAdd(m1, m4), m5, '-'), m7)
        c01 = matrixAdd(m3, m5)
        c10 = matrixAdd(m2, m4)
        c11 = matrixAdd(matrixAdd(matrixAdd(m1, m3), m2, '-'), m6)
        return matrixMerge(c00, c01, c10, c11)
    elif len(a) == 2 and len(b) == 2:
        return matrixNormalMultiply(a, b)


def matrixShow(a):
    for i in range(len(a)):
        for j in range(len(a[0])):
            print(a[i][j], end=' ')
        print("")
    print('')


if __name__ == '__main__':
    a = generateMatrix(4, 10, 99)
    b = generateMatrix(4, 10, 99)
    matrixShow(a)
    matrixShow(b)
    matrixShow(strassenMatrixMultiply(a, b))

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值