n阶方阵乘法straseen

原理:分块矩阵乘法,进行8次矩阵乘法,时间复杂度为 $\theta(n^3) = \theta(n^{\lg{8}}) $ , 改进后仅需要7次乘法, 时间复杂度为 \(\theta(n^{\lg{7}})\)
具体推到见算法导论中利用主定理推导时间复杂度

def matrix_divide(A):
    rows = len(A)
    mid = rows // 2
    A11 = [[0]*mid for _ in range(mid)]
    A12 = [[0]*mid for _ in range(mid)]
    A21 = [[0]*mid for _ in range(mid)]
    A22 = [[0]*mid for _ in range(mid)]

    for i in range(mid):
        for j in range(mid):
            A11[i][j] = A[i][j]
            A12[i][j] = A[i][mid+j]
            A21[i][j] = A[mid+i][j]
            A22[i][j] = A[mid+i][mid+j]
    return A11, A12, A21, A22

def matrix_add(A, B):
    rows = len(A)
    C = [[0]*rows for _ in range(rows)]
    for i in range(rows):
        for j in range(rows):
            C[i][j] = A[i][j] + B[i][j]
    return C

def matrix_sub(A, B):
    rows = len(A)
    C = [[0]*rows for _ in range(rows)]
    for i in range(rows):
        for j in range(rows):
            C[i][j] = A[i][j] - B[i][j]
    return C


def matrix_merge(C11, C12, C21, C22):
    rows = len(C11)
    n = rows * 2
    C = [[0]*n for _ in range(n)]
    for i in range(rows):
        for j in range(rows):
            C[i][j] = C11[i][j]
            C[i][rows+j] = C12[i][j]
            C[rows+i][j] = C21[i][j]
            C[rows+i][rows+j] = C22[i][j]
    return C


def strassen(A, B):
    n = len(A)
    C = [[0] for _ in range(n)]
    if n == 1:
        C[0][0] = A[0][0]*B[0][0]
        return C
    A11, A12, A21, A22 = matrix_divide(A)
    B11, B12, B21, B22 = matrix_divide(B)

    S1 = matrix_sub(B12, B22)
    S2 = matrix_add(A11, A12)
    S3 = matrix_add(A21, A22)
    S4 = matrix_sub(B21, B11)
    S5 = matrix_add(A11, A22)
    S6 = matrix_add(B11, B22)
    S7 = matrix_sub(A12, A22)
    S8 = matrix_add(B21, B22)
    S9 = matrix_sub(A11, A21)
    S10 = matrix_add(B11, B12)
    
    P1 = strassen(A11, S1)
    P2 = strassen(S2, B22)
    P3 = strassen(S3, B11)
    P4 = strassen(A22, S4)
    P5 = strassen(S5, S6)
    P6 = strassen(S7, S8)
    P7 = strassen(S9, S10)

    C11 = matrix_add(P5, matrix_sub(P4, matrix_sub(P2, P6)))
    C12 = matrix_add(P1, P2)
    C21 = matrix_add(P3, P4)
    C22 = matrix_add(P5, matrix_sub(P1, matrix_add(P3, P7)))
    
    return matrix_merge(C11, C12, C21, C22)
def main():
    A = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
    B = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]
    C = strassen(A, B)
    print(C)
if __name__ == '__main__':
    main()

转载于:https://www.cnblogs.com/vito_wang/p/10806816.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值