Strassen Matrix Multiplication
Problem:
Let $A$ and $B$ be two $n\times n$ matrices. Compute their product $C=AB$.
Analysis:
Method 1:
The most basic approach is brute force, computing each entry directly from the definition:

$$c_{ij}=\sum_{k=1}^{n}a_{ik}b_{kj}$$
When the product $C$ of $A$ and $B$ is computed from this definition, each entry of $C$ requires $n$ multiplications and $n-1$ additions. Since $C$ has $n^2$ entries, the time complexity is $T(n)=O(n^3)$.
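A minimal Python sketch of this brute-force method, assuming matrices are stored as lists of lists; the function name `naive_matmul` is illustrative, not from the original. The three nested loops make the $O(n^3)$ cost visible.

```python
def naive_matmul(A, B):
    """Multiply two n x n matrices (lists of lists) by the definition
    c_ij = sum_k a_ik * b_kj. There are n^2 entries, each needing n
    multiplications, so the total cost is O(n^3)."""
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for i in range(n):          # row of C
        for j in range(n):      # column of C
            s = 0
            for k in range(n):  # inner product of row i of A and column j of B
                s += A[i][k] * B[k][j]
            C[i][j] = s
    return C


print(naive_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```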
Method 2:
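Strassen matrix multiplication applies a divide-and-conquer idea similar to that of large-integer multiplication. Assuming $n$ is a power of 2, split $A$ and $B$ at row and column $n/2$ into four $(n/2)\times(n/2)$ blocks each; $C$ can then be written in the corresponding block form: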
$$\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix} = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}$$
It follows that:
$$\begin{aligned}
C_{11}&=A_{11}B_{11}+A_{12}B_{21}\\
C_{12}&=A_{11}B_{12}+A_{12}B_{22}\\
C_{21}&=A_{21}B_{11}+A_{22}B_{21}\\
C_{22}&=A_{21}B_{12}+A_{22}B_{22}
\end{aligned}$$
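These block formulas translate directly into a recursive sketch. This is an illustrative Python version, assuming $n$ is a power of 2 and recursing down to $1\times 1$ blocks instead of stopping at $n=2$; the helper names `add`, `split`, `join` and `block_matmul` are hypothetical. Each call performs exactly the 8 half-size block products of the formulas above.

```python
def add(X, Y):
    """Entry-wise sum of two equally sized square matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Split an n x n matrix (n even) into blocks M11, M12, M21, M22."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def join(C11, C12, C21, C22):
    """Reassemble four n/2 x n/2 blocks into one n x n matrix."""
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


def block_matmul(A, B):
    """Divide and conquer with 8 recursive block products.
    Assumes n is a power of 2; still O(n^3) overall."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    C11 = add(block_matmul(A11, B11), block_matmul(A12, B21))
    C12 = add(block_matmul(A11, B12), block_matmul(A12, B22))
    C21 = add(block_matmul(A21, B11), block_matmul(A22, B21))
    C22 = add(block_matmul(A21, B12), block_matmul(A22, B22))
    return join(C11, C12, C21, C22)


print(block_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```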
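Computing $C$ this way takes 8 multiplications of $(n/2)\times(n/2)$ blocks plus 4 block additions, each addition costing $O(n^2)$, so the running time satisfies: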
$$T(n)=\begin{cases} O(1), & n=2 \\ 8T(n/2)+O(n^2), & n>2 \end{cases}$$
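By the master theorem ($n^{\log_2 8}=n^3$ dominates the $O(n^2)$ term), this recurrence solves to $T(n)=O(n^3)$, which is no more efficient than Method 1. As with large-integer multiplication, Strassen devised a set of formulas that reduce the number of subproblems: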
$$\begin{aligned}
M_1&=A_{11}(B_{12}-B_{22})\\
M_2&=(A_{11}+A_{12})B_{22}\\
M_3&=(A_{21}+A_{22})B_{11}\\
M_4&=A_{22}(B_{21}-B_{11})\\
M_5&=(A_{11}+A_{22})(B_{11}+B_{22})\\
M_6&=(A_{12}-A_{22})(B_{21}+B_{22})\\
M_7&=(A_{11}-A_{21})(B_{11}+B_{12})
\end{aligned}$$
The blocks of $C$ can then be expressed as:
$$\begin{aligned}
C_{11}&=M_5+M_4-M_2+M_6\\
C_{12}&=M_1+M_2\\
C_{21}&=M_3+M_4\\
C_{22}&=M_5+M_1-M_3-M_7
\end{aligned}$$
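As a quick sanity check of the seven products, the sketch below applies $M_1,\dots,M_7$ to $1\times 1$ blocks (plain integers) and compares the result with the direct block formulas on random inputs. The function names `strassen_2x2` and `direct_2x2` are hypothetical. Because every product keeps the $A$ factor on the left and the $B$ factor on the right, the same identities also hold when the entries are non-commuting matrix blocks, although this scalar test does not exercise that case.

```python
import random


def strassen_2x2(a11, a12, a21, a22, b11, b12, b21, b22):
    """Combine scalar entries with Strassen's 7 products M1..M7."""
    m1 = a11 * (b12 - b22)
    m2 = (a11 + a12) * b22
    m3 = (a21 + a22) * b11
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a22) * (b11 + b22)
    m6 = (a12 - a22) * (b21 + b22)
    m7 = (a11 - a21) * (b11 + b12)
    return (m5 + m4 - m2 + m6,   # C11
            m1 + m2,             # C12
            m3 + m4,             # C21
            m5 + m1 - m3 - m7)   # C22


def direct_2x2(a11, a12, a21, a22, b11, b12, b21, b22):
    """The four formulas from the definition (8 multiplications)."""
    return (a11 * b11 + a12 * b21, a11 * b12 + a12 * b22,
            a21 * b11 + a22 * b21, a21 * b12 + a22 * b22)


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    vals = [rng.randint(-50, 50) for _ in range(8)]
    assert strassen_2x2(*vals) == direct_2x2(*vals)
print(strassen_2x2(1, 2, 3, 4, 5, 6, 7, 8))  # (19, 22, 43, 50)
```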
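The number of subproblems drops from 8 to 7. The extra block additions and subtractions (18 in total) still cost only $O(n^2)$, so the recurrence becomes $T(n)=7T(n/2)+O(n^2)$, which gives $T(n)=O(n^{\log_2 7})\approx O(n^{2.81})$, an asymptotic improvement over $O(n^3)$.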
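Putting the pieces together, here is a minimal recursive Strassen sketch in Python, assuming $n$ is a power of 2 and recursing all the way to $1\times 1$ blocks; all function names are illustrative. The helper `scalar_mults` (also hypothetical) counts scalar multiplications under full recursion: $8^k=n^3$ with 8 subproducts per level versus $7^k=n^{\log_2 7}$ with 7.

```python
def add(X, Y):
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Split an n x n matrix (n even) into blocks M11, M12, M21, M22."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def join(C11, C12, C21, C22):
    return [r1 + r2 for r1, r2 in zip(C11, C12)] + \
           [r1 + r2 for r1, r2 in zip(C21, C22)]


def strassen(A, B):
    """Strassen multiplication of n x n matrices; assumes n is a power of 2."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # The 7 recursive half-size products
    M1 = strassen(A11, sub(B12, B22))
    M2 = strassen(add(A11, A12), B22)
    M3 = strassen(add(A21, A22), B11)
    M4 = strassen(A22, sub(B21, B11))
    M5 = strassen(add(A11, A22), add(B11, B22))
    M6 = strassen(sub(A12, A22), add(B21, B22))
    M7 = strassen(sub(A11, A21), add(B11, B12))
    # Recombine into the four blocks of C
    C11 = add(sub(add(M5, M4), M2), M6)
    C12 = add(M1, M2)
    C21 = add(M3, M4)
    C22 = sub(sub(add(M5, M1), M3), M7)
    return join(C11, C12, C21, C22)


def scalar_mults(n, branches):
    """Scalar multiplications when each level spawns `branches` half-size products."""
    return 1 if n == 1 else branches * scalar_mults(n // 2, branches)


print(strassen([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
print(scalar_mults(64, 8), scalar_mults(64, 7))      # 262144 117649
```

Recursing to $1\times 1$ keeps the sketch short, but because of the many extra additions, practical implementations typically switch to the brute-force method once blocks fall below some cutoff size, and pad matrices whose size is not a power of 2.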