Strassen 矩阵乘法

Strassen 矩阵乘法

矩阵大小限制为2的n次方

矩阵大小限制为2的n次方

# -*- coding: utf-8 -*-            
# @Time : 2022-03-10 20:55
# @Author: Fu Lebing
# @FileName: Strassen.py
# @Software: PyCharm
import numpy as np


def matrix_minus(Matrix_A, Matrix_B):  # 矩阵减法
    rows = len(Matrix_A)
    columns = len(Matrix_A[0])
    Matrix_C = [list() for i in range(rows)]
    for i in range(rows):
        for j in range(columns):
            Matrix_C[i].append(Matrix_A[i][j] - Matrix_B[i][j])
    return Matrix_C


def matrix_add(Matrix_A, Matrix_B):  # 矩阵加法
    rows = len(Matrix_A)
    columns = len(Matrix_A[0])
    Matrix_C = [list() for i in range(rows)]
    for i in range(rows):
        for j in range(columns):
            Matrix_C[i].append(Matrix_A[i][j] + Matrix_B[i][j])
    return Matrix_C


def matrix_divide(Matrix_A, row, column):
    # 将矩阵A划分为四个大小相同的矩阵, 第一个矩阵的位置是[row, column] = [1, 1],
    # 第二个[row, column] = [1, 2], 第三个[row, column] = [2, 1], 第四个[row, column] = [2, 2]

    length = len(Matrix_A) // 2  # length为划分后的矩阵大小
    Matrix_B = [list() for i in range(length)]
    k = 0
    for i in range((row - 1) * length, row * length):
        for j in range((column - 1) * length, column * length):
            Matrix_B[k].append(Matrix_A[i][j])
        k += 1
    return Matrix_B


def matrix_merge(Matrix_11, Matrix_12, Matrix_21, Matrix_22):# 拼接四个小矩阵成为一个大矩阵
    length = len(Matrix_11)
    Matrix_All = [list() for i in range(length * 2)]  # 拼接后的矩阵的长宽为小矩阵的2倍
    for i in range(length):
        Matrix_All[i] = Matrix_11[i] + Matrix_12[i]
    for j in range(length):
        Matrix_All[j + length] = Matrix_21[j] + Matrix_22[j]
    return Matrix_All


def Strassen(Matrix_A, Matrix_B):
    if len(Matrix_A) == 1:
        matrix_all = [list() for i in range(1)]
        matrix_all[0].append(Matrix_A[0][0] * Matrix_B[0][0])
    else:
        MatLength = len(Matrix_A)
        NextMatLength = MatLength / 2
        a00 = matrix_divide(Matrix_A, 1, 1)
        a01 = matrix_divide(Matrix_A, 1, 2)
        a10 = matrix_divide(Matrix_A, 2, 1)
        a11 = matrix_divide(Matrix_A, 2, 2)
        b00 = matrix_divide(Matrix_B, 1, 1)
        b01 = matrix_divide(Matrix_B, 1, 2)
        b10 = matrix_divide(Matrix_B, 2, 1)
        b11 = matrix_divide(Matrix_B, 2, 2)
        m1 = Strassen((matrix_add(a00, a11)), matrix_add(b00, b11))
        m2 = Strassen(matrix_add(a10, a11), b00)
        m3 = Strassen(a00, matrix_minus(b01, b11))
        m4 = Strassen(a11, matrix_minus(b10, b00))
        m5 = Strassen(matrix_add(a00, a01), b11)
        m6 = Strassen(matrix_minus(a10, a00), matrix_add(b00, b01))
        m7 = Strassen(matrix_minus(a01, a11), matrix_add(b10, b11))
        matrix_all = matrix_merge(matrix_add(matrix_add(m1, m4), matrix_minus(m7, m5)), matrix_add(m3, m5), matrix_add(m2, m4), matrix_add(matrix_add(m1, m3), matrix_minus(m6, m2)))

    return matrix_all
def testnumpy():
    arr1 = np.random.randint(0, 10, size=(16, 16))  # strassen的矩阵大小限定2的n次方
    arr2 = np.random.randint(10, 20, size=(16, 16))
    # arr1 = [[1, 2], [3, 4]]
    # arr2 = [[1, 2], [3, 4]]
    result_arr = np.dot(arr1, arr2)
    return arr1, arr2, result_arr

if __name__ == '__main__':
    arr1, arr2, result_numpy = testnumpy()
    result_Strassen = Strassen(arr1, arr2)
    print(result_Strassen)
    print(result_numpy)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值