公式
这个公式可比Strassen算法复杂多了啊。假设有以下两个矩阵,元素是任何代数环(实数、复数、函数、矩阵等):
$$
A= \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix}\right],\qquad
B= \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix}\right]
$$
相比于Strassen算法的7个乘积变量,Winograd变体一共用到了22个中间变量(S_1–S_4、T_1–T_4、M_1–M_7、U_1–U_7),靠记忆肯定是记不住的。还是那句话,建议收藏本文,随时翻阅。算法公式如下:
$$
\begin{aligned}
S_1 &= A_{21}+A_{22} & T_1 &= B_{12}-B_{11}\\
S_2 &= S_1-A_{11} & T_2 &= B_{22}-T_1\\
S_3 &= A_{11}-A_{21} & T_3 &= B_{22}-B_{12}\\
S_4 &= A_{12}-S_2 & T_4 &= T_2-B_{21}\\[6pt]
M_1 &= A_{11}B_{11} & U_1 &= M_1+M_2\\
M_2 &= A_{12}B_{21} & U_2 &= M_1+M_6\\
M_3 &= S_4B_{22} & U_3 &= U_2+M_7\\
M_4 &= A_{22}T_4 & U_4 &= U_2+M_5\\
M_5 &= S_1T_1 & U_5 &= U_4+M_3\\
M_6 &= S_2T_2 & U_6 &= U_3-M_4\\
M_7 &= S_3T_3 & U_7 &= U_3+M_5
\end{aligned}
$$

$$
AB= \begin{pmatrix} U_1 & U_5\\ U_6 & U_7 \end{pmatrix}
$$
与Strassen算法(7次乘法、18次加减法)相比,Winograd变体同样只用7次乘法,但加减法减少到15次(S、T各4次,U共7次)。与其他基于张量的算法相比,它更容易理解,学习难度也小得多。
Python实现
class Matrix:
    """Matrix view supporting Strassen-Winograd block multiplication.

    An instance is the half-open window
    ``[row_start, row_end) x [column_start, column_end)`` onto a backing list
    of rows ``lines``.  Quadrant views produced by :meth:`sub` share that
    backing list with their parent, so splitting copies no entries; every
    arithmetic operation (``+``, ``-``, ``*``) returns a new matrix that owns
    a freshly built backing list.
    """

    @staticmethod
    def create_by_lines(lines):
        """Wrap a complete list of rows as a matrix covering all of it.

        ``lines[i][j]`` is the entry in row ``i``, column ``j``.  An empty
        ``lines`` gives a 0 x 0 matrix (there is no first row to read the
        column count from).
        """
        column_count = len(lines[0]) if lines else 0
        return Matrix(lines, 0, len(lines), 0, column_count)

    def __init__(self, lines, row_start, row_end, column_start, column_end):
        # The four half-open bounds select the block of ``lines`` that this
        # object represents; they are what lets sub() split without copying.
        self.__lines = lines
        self.__column_start = column_start
        self.__column_end = column_end
        self.__row_start = row_start
        self.__row_end = row_end

    def __repr__(self):
        # Show only this view's own block, not the whole backing list.
        rows = [self.__lines[self.__row_start + i][self.__column_start:self.__column_end]
                for i in range(self.row_len())]
        return "Matrix(%r)" % (rows,)

    def __mul__(self, other):
        """Return ``self * other`` as a new matrix.

        Performs one level of the Strassen-Winograd scheme (7 block products,
        15 block additions/subtractions); the block products recurse through
        ``*``.  Falls back to :meth:`plain_mul` whenever the operands cannot
        be split into equally sized quadrants.

        Raises:
            ValueError: if ``self``'s column count differs from ``other``'s
                row count.
        """
        if self.column_len() != other.row_len():
            # Report the sizes of the views themselves, not of the backing
            # lists (they differ for quadrants produced by sub()).
            raise ValueError("矩阵A列数%d != 矩阵B的行数%d" % (self.column_len(), other.row_len()))
        # For an (m x n) * (n x p) product, m, n and p must all be even for
        # the quadrants to have matching shapes -- an odd column count in
        # ``other`` alone would make b11 and b12 different widths.  Sizes
        # below 2 (including empty matrices) cannot be halved usefully.
        dims = (self.row_len(), self.column_len(), other.column_len())
        if min(dims) < 2 or any(d & 1 for d in dims):
            return self.plain_mul(other)

        a11, a12, a21, a22 = self.sub()
        b11, b12, b21, b22 = other.sub()
        # Pre-additions on the left operand.
        s1 = a21 + a22
        s2 = s1 - a11
        s3 = a11 - a21
        s4 = a12 - s2
        # Pre-additions on the right operand.
        t1 = b12 - b11
        t2 = b22 - t1
        t3 = b22 - b12
        t4 = t2 - b21
        # The seven (recursive) block products.
        m1 = a11 * b11
        m2 = a12 * b21
        m3 = s4 * b22
        m4 = a22 * t4
        m5 = s1 * t1
        m6 = s2 * t2
        m7 = s3 * t3
        # Post-additions; u1, u5, u6, u7 are the four result quadrants.
        u1 = m1 + m2
        u2 = m1 + m6
        u3 = u2 + m7
        u4 = u2 + m5
        u5 = u4 + m3
        u6 = u3 - m4
        u7 = u3 + m5
        return Matrix.create(u1, u5, u6, u7)

    def _elementwise(self, other, op):
        """Return a new matrix whose entry (i, j) is ``op(self[i][j], other[i][j])``.

        Raises:
            ValueError: if the two views do not have the same shape.
        """
        if self.row_len() != other.row_len() or self.column_len() != other.column_len():
            raise ValueError("矩阵形状不一致: %dx%d != %dx%d" % (
                self.row_len(), self.column_len(), other.row_len(), other.column_len()))
        width = self.column_len()
        a_col0 = self.__column_start
        b_col0 = other.__column_start
        result = []
        for i in range(self.row_len()):
            a_row = self.__lines[self.__row_start + i]
            b_row = other.__lines[other.__row_start + i]
            # A fresh list per row, so result rows never alias each other.
            result.append([op(a_row[a_col0 + j], b_row[b_col0 + j]) for j in range(width)])
        return Matrix.create_by_lines(result)

    def __add__(self, other):
        """Return the entry-wise sum ``self + other`` (shapes must match)."""
        return self._elementwise(other, lambda x, y: x + y)

    def __sub__(self, other):
        """Return the entry-wise difference ``self - other`` (shapes must match)."""
        return self._elementwise(other, lambda x, y: x - y)

    def plain_mul(self, other):
        """Return ``self * other`` computed with the schoolbook triple loop.

        Entry ``(i, j)`` is accumulated left to right as
        ``0 + a[i][0]*b[0][j] + a[i][1]*b[1][j] + ...``.  The inner length is
        taken from ``other.row_len()``; ``__mul__`` checks beforehand that it
        equals ``self.column_len()``.
        """
        row_count = self.row_len()
        column_count = other.column_len()
        inner = other.row_len()
        a_lines, a_row0, a_col0 = self.__lines, self.__row_start, self.__column_start
        b_lines, b_row0, b_col0 = other.__lines, other.__row_start, other.__column_start
        result = [[0] * column_count for _ in range(row_count)]
        for i in range(row_count):
            # Row i of self is invariant for the j/k loops: fetch it once.
            a_row = a_lines[a_row0 + i]
            out_row = result[i]
            for j in range(column_count):
                col = b_col0 + j
                total = 0
                for k in range(inner):
                    total += a_row[a_col0 + k] * b_lines[b_row0 + k][col]
                out_row[j] = total
        return Matrix.create_by_lines(result)

    def row_len(self):
        """Number of rows in this view."""
        return self.__row_end - self.__row_start

    def column_len(self):
        """Number of columns in this view."""
        return self.__column_end - self.__column_start

    def sub(self):
        """Split into quadrant views ``(a11, a12, a21, a22)`` sharing this backing list.

        The split points are the midpoints of the row and column ranges
        (rounded down), so the quadrants are equally sized only when both
        dimensions are even.
        """
        a_middle_row = (self.__row_end + self.__row_start) // 2
        a_middle_column = (self.__column_end + self.__column_start) // 2
        a11 = Matrix(self.__lines, self.__row_start, a_middle_row, self.__column_start, a_middle_column)
        a12 = Matrix(self.__lines, self.__row_start, a_middle_row, a_middle_column, self.__column_end)
        a21 = Matrix(self.__lines, a_middle_row, self.__row_end, self.__column_start, a_middle_column)
        a22 = Matrix(self.__lines, a_middle_row, self.__row_end, a_middle_column, self.__column_end)
        return a11, a12, a21, a22

    @staticmethod
    def create(a11, a12, a21, a22):
        """Assemble four quadrant matrices into one new matrix.

        The result has ``a11.row_len() + a21.row_len()`` rows and
        ``a11.column_len() + a12.column_len()`` columns.
        """
        len_rows = a11.row_len() + a21.row_len()
        len_columns = a11.column_len() + a12.column_len()
        lines = [[0] * len_columns for _ in range(len_rows)]
        a11.copy_to(lines, 0, 0)
        a12.copy_to(lines, 0, a11.column_len())
        a21.copy_to(lines, a11.row_len(), 0)
        a22.copy_to(lines, a12.row_len(), a21.column_len())
        return Matrix.create_by_lines(lines)

    def copy_to(self, lines, row_start, column_start):
        """Copy this view's entries into ``lines`` with the top-left corner at
        ``(row_start, column_start)``; ``lines`` must already be large enough."""
        width = self.column_len()
        for i in range(self.row_len()):
            source = self.__lines[self.__row_start + i]
            target = lines[row_start + i]
            for j in range(width):
                target[column_start + j] = source[self.__column_start + j]

    @property
    def lines(self):
        """The backing list of rows.

        For a view created by :meth:`sub` this is the parent's full list, not
        just the view's block; results of ``+``, ``-`` and ``*`` own a list
        whose shape matches the result exactly.
        """
        return self.__lines
Python测试代码
写完一定要测试,否则很难发现其中的bug。特拉维夫大学的一份课件中就有一处这样的错误。
以下是我的自测代码:
# _*_ coding:utf-8 _*_
from com.youngthing.mathalgorithm.matrix import Matrix
from com.youngthing.mathalgorithm.linearalgebra.winograd import Matrix as WinogradMatrix
if __name__ == '__main__':
    # Each case is (A, B).  Both sizes make the Winograd path recurse:
    # 2x2 -> 1x1 blocks, and 4x4 -> 2x2 -> 1x1 blocks.
    cases = [
        ([[1, 2, ], [4, 5, ], ],  # A: 2x2
         [[2, 1, ], [3, 6, ]]),  # B: 2x2
        ([[1, 2, 3, 5], [4, 5, 6, 6], [7, 8, 9, 7], [1, 2, 3, 8]],  # A: 4x4
         [[2, 1, 5, 7], [3, 6, 9, -1], [4, 8, -2, -3], [-4, -8, -5, -6]]),  # B: 4x4
    ]
    for lines_a, lines_b in cases:
        # Product from the plain reference Matrix implementation.
        print(Matrix(lines_a) * Matrix(lines_b))
        c = WinogradMatrix.create_by_lines(lines_a) * WinogradMatrix.create_by_lines(lines_b)
        print(c.lines)
        # Independent check straight from the definition: entry (i, j) is the
        # dot product of row i of A with column j of B.  Printing alone would
        # not catch a wrong Winograd formula.
        expected = [[sum(x * y for x, y in zip(row, column)) for column in zip(*lines_b)]
                    for row in lines_a]
        assert c.lines == expected, (c.lines, expected)
矩阵乘法其实还有渐近复杂度更低的算法,比如东京大学的François Le Gall在2014年给出的算法,但因为太复杂了,所以我就不写了,有兴趣的可以去搜索相关资料。接下来我要介绍与标准矩阵乘法不一样的另外两种矩阵乘法:Kronecker积与阿达马(Hadamard)积。