1.3 矩阵相乘Winograd算法

公式

  这个公式可比Strassen算法复杂多了啊。假设有以下两个矩阵,元素是任何代数环(实数、复数、函数、矩阵等):
$$
A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}, \qquad
B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
$$
  相比于Strassen算法的7个中间变量,Winograd算法用到了22个中间变量(S、T各4个,M、U各7个),靠记忆是肯定记不住的了。还是那句话,建议收藏本文,随时翻阅。算法公式如下:
$$
\begin{aligned}
S_1 &= A_{21} + A_{22} & T_1 &= B_{12} - B_{11} \\
S_2 &= S_1 - A_{11}    & T_2 &= B_{22} - T_1 \\
S_3 &= A_{11} - A_{21} & T_3 &= B_{22} - B_{12} \\
S_4 &= A_{12} - S_2    & T_4 &= T_2 - B_{21}
\end{aligned}
$$

$$
\begin{aligned}
M_1 &= A_{11} B_{11} \\
M_2 &= A_{12} B_{21} \\
M_3 &= S_4 B_{22} \\
M_4 &= A_{22} T_4 \\
M_5 &= S_1 T_1 \\
M_6 &= S_2 T_2 \\
M_7 &= S_3 T_3
\end{aligned}
$$

$$
\begin{aligned}
U_1 &= M_1 + M_2 \\
U_2 &= M_1 + M_6 \\
U_3 &= U_2 + M_7 \\
U_4 &= U_2 + M_5 \\
U_5 &= U_4 + M_3 \\
U_6 &= U_3 - M_4 \\
U_7 &= U_3 + M_5
\end{aligned}
$$

$$
AB = \begin{pmatrix} U_1 & U_5 \\ U_6 & U_7 \end{pmatrix}
$$
  Strassen算法需要7次乘法和18次加减法,而Winograd算法同样是7次乘法,却只用了15次加减法。与其他张量类的算法相比,它更容易理解,学习难度小多了。

Python实现

class Matrix:
    """Matrix with block views, multiplied by the Winograd variant of Strassen.

    An instance is a rectangular window onto a shared list of rows. The window
    is described by half-open row and column ranges, so splitting a matrix into
    quadrants copies no data. Arithmetic operators always return a new matrix
    that owns freshly allocated rows starting at offset (0, 0).
    """

    @staticmethod
    def create_by_lines(lines):
        """Wrap a complete list of rows as a matrix covering all of it.

        An empty list yields a 0x0 matrix instead of raising IndexError.
        """
        column_count = len(lines[0]) if lines else 0
        return Matrix(lines, 0, len(lines), 0, column_count)

    def __init__(self, lines, row_start, row_end, column_start, column_end):
        """Create a view of ``lines`` limited to the given half-open ranges."""
        self.__lines = lines
        # The four bounds are what make block (quadrant) views possible.
        self.__column_start = column_start
        self.__column_end = column_end
        self.__row_start = row_start
        self.__row_end = row_end

    def __repr__(self):
        """Show only the rows and columns that belong to this view."""
        visible = [
            row[self.__column_start:self.__column_end]
            for row in self.__lines[self.__row_start:self.__row_end]
        ]
        return "Matrix(%r)" % (visible,)

    def __mul__(self, other):
        """Return the matrix product ``self * other``.

        Uses the Winograd variant of Strassen's algorithm (7 block products,
        15 block additions/subtractions) recursively while every dimension is
        even and at least 2; otherwise falls back to ``plain_mul``.

        Raises:
            ValueError: if the column count of ``self`` differs from the row
                count of ``other``.
        """
        if self.column_len() != other.row_len():
            # Report the dimensions of the views, not of the backing storage.
            raise ValueError(
                "矩阵A列数%d != 矩阵B的行数%d" % (self.column_len(), other.row_len())
            )

        # All three dimensions (rows of A, shared inner size, columns of B)
        # must split into equal halves. An odd size cannot be halved, and a
        # size of 0 would recurse forever, so both go to the plain algorithm.
        dimensions = (self.row_len(), self.column_len(), other.column_len())
        if any(size < 2 or size % 2 for size in dimensions):
            return self.plain_mul(other)

        a11, a12, a21, a22 = self.sub()
        b11, b12, b21, b22 = other.sub()

        # Pre-additions on the blocks of A.
        s1 = a21 + a22
        s2 = s1 - a11
        s3 = a11 - a21
        s4 = a12 - s2

        # Pre-additions on the blocks of B.
        t1 = b12 - b11
        t2 = b22 - t1
        t3 = b22 - b12
        t4 = t2 - b21

        # The seven recursive block products.
        m1 = a11 * b11
        m2 = a12 * b21
        m3 = s4 * b22
        m4 = a22 * t4
        m5 = s1 * t1
        m6 = s2 * t2
        m7 = s3 * t3

        # Post-additions; u1, u5, u6, u7 are the four quadrants of the result.
        u1 = m1 + m2
        u2 = m1 + m6
        u3 = u2 + m7
        u4 = u2 + m5
        u5 = u4 + m3
        u6 = u3 - m4
        u7 = u3 + m5

        return Matrix.create(u1, u5, u6, u7)

    def _elementwise(self, other, combine):
        """Apply ``combine`` to corresponding elements and return a new matrix.

        Raises:
            ValueError: if the two views do not have the same shape.
        """
        rows = self.row_len()
        columns = self.column_len()
        if rows != other.row_len() or columns != other.column_len():
            raise ValueError(
                "矩阵形状不一致: %dx%d 与 %dx%d"
                % (rows, columns, other.row_len(), other.column_len())
            )
        result = []
        for i in range(rows):
            self_row = self.__lines[self.__row_start + i]
            other_row = other.__lines[other.__row_start + i]
            # A new list per row, so result rows never alias each other.
            result.append([
                combine(self_row[self.__column_start + j],
                        other_row[other.__column_start + j])
                for j in range(columns)
            ])
        # Explicit bounds keep the column count even when there are no rows.
        return Matrix(result, 0, rows, 0, columns)

    def __add__(self, other):
        """Return the element-wise sum of two equally shaped matrices."""
        return self._elementwise(other, lambda x, y: x + y)

    def __sub__(self, other):
        """Return the element-wise difference of two equally shaped matrices."""
        return self._elementwise(other, lambda x, y: x - y)

    def plain_mul(self, other):
        """Multiply with the schoolbook triple loop and return a new matrix.

        The result has ``self.row_len()`` rows and ``other.column_len()``
        columns; the inner dimension is taken from ``other.row_len()``.
        """
        m = self.row_len()
        n = other.column_len()
        p = other.row_len()

        result = [[0] * n for _ in range(m)]
        for i in range(m):
            # Row i of A, looked up once instead of once per element.
            self_line = self.__lines[self.__row_start + i]
            result_row = result[i]
            for j in range(n):
                other_column = other.__column_start + j
                total = result_row[j]
                # Dot product of row i of A with column j of B.
                for k in range(p):
                    a = self_line[self.__column_start + k]
                    b = other.__lines[other.__row_start + k][other_column]
                    total += a * b
                result_row[j] = total
        return Matrix(result, 0, m, 0, n)

    def row_len(self):
        """Return the number of rows in this view."""
        return self.__row_end - self.__row_start

    def column_len(self):
        """Return the number of columns in this view."""
        return self.__column_end - self.__column_start

    def sub(self):
        """Split this view into four quadrant views sharing the same storage.

        Returns:
            The tuple ``(a11, a12, a21, a22)``: top-left, top-right,
            bottom-left and bottom-right.
        """
        a_middle_row = (self.__row_end + self.__row_start) // 2
        a_middle_column = (self.__column_end + self.__column_start) // 2
        a11 = Matrix(self.__lines, self.__row_start, a_middle_row,
                     self.__column_start, a_middle_column)
        a12 = Matrix(self.__lines, self.__row_start, a_middle_row,
                     a_middle_column, self.__column_end)
        a21 = Matrix(self.__lines, a_middle_row, self.__row_end,
                     self.__column_start, a_middle_column)
        a22 = Matrix(self.__lines, a_middle_row, self.__row_end,
                     a_middle_column, self.__column_end)
        return a11, a12, a21, a22

    @staticmethod
    def create(a11, a12, a21, a22):
        """Assemble four quadrant matrices into one newly allocated matrix."""
        top_rows = a11.row_len()
        left_columns = a11.column_len()
        len_rows = top_rows + a21.row_len()
        len_columns = left_columns + a12.column_len()
        lines = [[0] * len_columns for _ in range(len_rows)]
        # Copy each quadrant to its offset; the offsets come from a11 alone.
        a11.copy_to(lines, 0, 0)
        a12.copy_to(lines, 0, left_columns)
        a21.copy_to(lines, top_rows, 0)
        a22.copy_to(lines, top_rows, left_columns)
        return Matrix(lines, 0, len_rows, 0, len_columns)

    def copy_to(self, lines, row_start, column_start):
        """Copy this view's elements into ``lines`` at the given offset."""
        for i in range(self.row_len()):
            self_row = self.__lines[self.__row_start + i]
            other_row = lines[row_start + i]
            for j in range(self.column_len()):
                other_row[column_start + j] = self_row[self.__column_start + j]

    @property
    def lines(self):
        """The backing list of rows.

        For a matrix returned by an operator this is exactly its content; for
        a quadrant view it is the whole storage of the matrix it was cut from.
        """
        return self.__lines

Python测试代码

  写完一定要测试,否则发现不了其中的bug啊。在某拉维夫大学的课件中就有以下一处错误:
在这里插入图片描述

  以下是我的自测代码:

# _*_ coding:utf-8 _*_
from com.youngthing.mathalgorithm.matrix import Matrix
from com.youngthing.mathalgorithm.linearalgebra.winograd import Matrix as WinogradMatrix

if __name__ == '__main__':
    # Case 1: a 2x2 matrix times a 2x2 matrix.
    small_left = [[1, 2], [4, 5]]
    small_right = [[2, 1], [3, 6]]
    # Reference product from the plain Matrix implementation.
    print(Matrix(small_left) * Matrix(small_right))
    # The same product through the Winograd implementation (prints the object).
    print(
        WinogradMatrix.create_by_lines(small_left)
        * WinogradMatrix.create_by_lines(small_right)
    )

    # Compute it again and print the raw rows so the values can be compared.
    winograd_small = (
        WinogradMatrix.create_by_lines(small_left)
        * WinogradMatrix.create_by_lines(small_right)
    )
    print(winograd_small.lines)

    # Case 2: a 4x4 matrix times a 4x4 matrix, which needs one level of
    # recursion down to 2x2 blocks.
    big_left = [
        [1, 2, 3, 5],
        [4, 5, 6, 6],
        [7, 8, 9, 7],
        [1, 2, 3, 8],
    ]
    big_right = [
        [2, 1, 5, 7],
        [3, 6, 9, -1],
        [4, 8, -2, -3],
        [-4, -8, -5, -6],
    ]
    print(Matrix(big_left) * Matrix(big_right))
    winograd_big = (
        WinogradMatrix.create_by_lines(big_left)
        * WinogradMatrix.create_by_lines(big_right)
    )
    print(winograd_big.lines)

  矩阵乘法的算法其实还有性能更好的,比如京都大学的Le Gall在2014年发现的算法,但因为太复杂了,所以我就不写了,有兴趣的可以去搜索相关资料。接下来我要介绍与标准矩阵乘法不一样的另外两种矩阵乘法:Kronecker积和阿达马积。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

醒过来摸鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值