4.2矩阵乘法的strassen算法
1.点积方法计算C=A*B,(A、B都为n*n的方阵)
c i j = ∑ k = 1 n a i k b k j c_{ij}=\sum_{k=1}^na_{ik}b_{kj} cij=k=1∑naikbkj
计算每个 c i j c_{ij} cij都要n步,而C矩阵共有n*n个 c i j c_{ij} cij需要计算,故时间复杂度为 O ( n 3 ) O(n^3) O(n3);
2.分治算法计算C
将A、B、C均分解为4个 n / 2 × n / 2 n/2\times n/2 n/2×n/2的子矩阵:
A = [ A 11 A 12 A 21 A 22 ] A=\begin{bmatrix} A_{11}& A_{12} \\ A_{21} & A_{22} \end{bmatrix} A=[A11A21A12A22] B = [ B 11 B 12 B 21 B 22 ] B=\begin{bmatrix} B_{11}& B_{12} \\ B_{21} & B_{22} \end{bmatrix} B=[B11B21B12B22] C = [ C 11 C 12 C 21 C 22 ] C=\begin{bmatrix} C_{11}& C_{12} \\ C_{21} & C_{22} \end{bmatrix} C=[C11C21C12C22]
故公式C=A·B改写为:
[ C 11 C 12 C 21 C 22 ] = [ A 11 A 12 A 21 A 22 ] ⋅ [ B 11 B 12 B 21 B 22 ] ⟺ \begin{bmatrix} C_{11}& C_{12} \\ C_{21} & C_{22} \end{bmatrix}=\begin{bmatrix} A_{11}& A_{12} \\ A_{21} & A_{22} \end{bmatrix}·\begin{bmatrix} B_{11}& B_{12} \\ B_{21} & B_{22} \end{bmatrix}\Longleftrightarrow [C11C21C12C22]=[A11A21A12A22]⋅[B11B21B12B22]⟺
C 11 = A 11 ⋅ B 11 + A 12 ⋅ B 21 C_{11}=A_{11}·B_{11}+A_{12}·B_{21} C11=A11⋅B11+A12⋅B21
C 12 = A 11 ⋅ B 12 + A 12 ⋅ B 22 C_{12}=A_{11}·B_{12}+A_{12}·B_{22} C12=A11⋅B12+A12⋅B22
C 21 = A 21 ⋅ B 11 + A 22 ⋅ B 21 C_{21}=A_{21}·B_{11}+A_{22}·B_{21} C21=A21⋅B11+A22⋅B21
C 11 = A 21 ⋅ B 12 + A 22 ⋅ B 22 C_{11}=A_{21}·B_{12}+A_{22}·B_{22} C11=A21⋅B12+A22⋅B22
即计算C=A·B可以分解为8个子问题来设计出递归分治算法:
算法运算时间的递归式为:
T
(
n
)
=
{
O
(
1
)
n
=
1
8
T
(
n
/
2
)
+
O
(
n
2
)
n
>
1
(
注
)
四
次
加
法
计
算
需
要
4
×
n
/
2
×
n
/
2
步
T(n)=\begin{cases}O(1)&&n=1\\8T(n/2)+O(n^2)&&n>1\end{cases}\\(注)四次加法计算需要4\times n/2 \times n/2步
T(n)={O(1)8T(n/2)+O(n2)n=1n>1(注)四次加法计算需要4×n/2×n/2步
此时我们运用主方法可得,时间复杂度
T
(
n
)
=
O
(
n
3
)
T(n)=O(n^3)
T(n)=O(n3)
3.改进分治算法,strassen方法
srassen的核心思想是令递归树不那么茂盛,即递归只分解为7个子问题而不是8个,减少一个子问题的代价是进行而外基础子矩阵的常数次加法。
1.每个矩阵均分解为4个 n / 2 × n / 2 n/2\times n/2 n/2×n/2的子矩阵:
利用下标计算法,花费O(1)时间
2.创建10个个 n / 2 × n / 2 n/2\times n/2 n/2×n/2的矩阵
由1中分解的子矩阵和或差得到,花费 O ( n 2 ) O(n^2) O(n2)时间
3.递归的计算7个矩阵积
利用1和2中的矩阵
4.将7个矩阵积的不同组合进行加减运算,计算出 C 11 、 C 12 、 C 21 、 C 22 C_{11}、C_{12}、C_{21}、C_{22} C11、C12、C21、C22
算法运算时间的递归式为:
T
(
n
)
=
{
O
(
1
)
n
=
1
7
T
(
n
/
2
)
+
O
(
n
2
)
n
>
1
T(n)=\begin{cases}O(1)&&n=1\\7T(n/2)+O(n^2)&&n>1\end{cases}
T(n)={O(1)7T(n/2)+O(n2)n=1n>1
4.原理应用
如果可以用k次乘法操作完成两个3*3矩阵相乘,那么你可以在 O ( n l g 7 ) O(n^{lg7}) O(nlg7)时间内完成n*n的乘法操作,满足这一条件的最大k是多少?
T(n) = kT(n/3)+O(1),通过主定理得 l o g 3 k < l o g 7 log_3{k}<log7 log3k<log7,故k=21。