矩阵乘法的 Strassen 算法
朴素算法时间复杂度:$\Theta(n^3)$;
一般分治算法:
$$
A=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix},\quad
B=\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix},\quad
C=\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
\tag{1}
$$
其中每个矩阵的四个子矩阵的规模均为 $n/2 \times n/2$,则:
$$
\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}
\cdot\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
$$
如此递归求解,则:
$$
T(n)=\begin{cases}
\Theta(1) & n=1 \\
8T(n/2)+\Theta(n^2) & n>1
\end{cases}
$$
解得
$$T(n)=\Theta(n^3)$$
Strassen算法:
- 仍按 $(1)$ 式将矩阵分解。
- 按一定公式计算 $S_1,S_2,\dots,S_{10}$(仅包含加减运算)。
- 按一定公式递归地计算7个矩阵积 $P_1,P_2,\dots,P_7$;每个矩阵规模都是 $n/2$。
- 通过 $P_i$ 矩阵的不同组合进行加减运算,得出 $C_{11},C_{12},C_{21},C_{22}$。
- 合并 $C_{11},C_{12},C_{21},C_{22}$ 得出 $C$。
得到此法递归式:
$$
T(n)=\begin{cases}
\Theta(1) & n=1 \\
7T(n/2)+\Theta(n^2) & n>1
\end{cases}
$$
解得
$$T(n)=\Theta(n^{\lg 7})$$
代码如下:
#define _CRT_SECURE_NO_WARNINGS
#include<stdio.h>
#define N 20
/*
 * Element-wise combination of two n x n matrices stored 1-based
 * (rows and columns 1..n) inside N x N arrays.
 *
 *   n : order of the matrices to combine
 *   a : left operand
 *   b : right operand
 *   c : destination; receives a + b when f == 1, a - b for any other f
 *   f : operation selector (1 = add, anything else = subtract)
 *
 * Each destination cell depends only on the same cell of a and b.
 * Row 0 and column 0 are never read or written.
 */
void ad(int n, int a[N][N], int b[N][N], int c[N][N], int f)
{
    int row, col;
    const int subtract = (f != 1);

    for (row = 1; row <= n; row++)
    {
        for (col = 1; col <= n; col++)
        {
            int lhs = a[row][col];
            int rhs = b[row][col];

            c[row][col] = subtract ? lhs - rhs : lhs + rhs;
        }
    }
}
/*
 * Strassen matrix multiplication: C = A * B.
 *
 *   n : order of the matrices, 1 <= n <= N - 1
 *   A : left operand, stored 1-based (rows/columns 1..n)
 *   B : right operand, stored 1-based
 *   C : destination; only cells [1..n][1..n] are written
 *
 * The matrices are split into four quadrants of order h = ceil(n / 2).
 * When n is odd the missing last row/column of the quadrants is filled
 * with zeros, which does not change the top-left n x n part of the
 * product, so n does not have to be a power of two.  For even n the
 * quadrants are exactly the n/2 x n/2 blocks of the operands.
 *
 * If n is outside 1..N-1 the function returns without touching C:
 * n <= 0 would otherwise never reach the n == 1 base case, and n >= N
 * does not fit a 1-based N x N array.
 *
 * Each call keeps its scratch matrices in automatic storage, so the stack
 * usage grows with the recursion depth.
 */
void cal(int n, int A[N][N], int B[N][N], int C[N][N])
{
    /* Guard against sizes that would recurse forever or overrun the arrays. */
    if (n <= 0 || n >= N)
        return;
    /* Recursion exit: 1 x 1 matrices are plain scalars. */
    if (n == 1)
    {
        C[1][1] = A[1][1] * B[1][1];
        return;
    }
    /*
     * Index 0 of every scratch array is unused so that the numbering
     * matches the usual S1..S10 / P1..P7 notation.  Quadrants are numbered
     * 1 = top-left, 2 = top-right, 3 = bottom-left, 4 = bottom-right.
     */
    int a[5][N][N], b[5][N][N], c[5][N][N], s[11][N][N], p[8][N][N];
    int h = (n + 1) / 2; /* quadrant order, rounded up for odd n */
    int q, i, j;
    /* Split A and B into quadrants, zero-filling cells outside 1..n. */
    for (q = 1; q <= 4; q++)
    {
        int ro = (q > 2) ? h : 0;      /* row offset of quadrant q */
        int co = (q % 2 == 0) ? h : 0; /* column offset of quadrant q */
        for (i = 1; i <= h; i++)
            for (j = 1; j <= h; j++)
            {
                int inside = (i + ro <= n) && (j + co <= n);
                a[q][i][j] = inside ? A[i + ro][j + co] : 0;
                b[q][i][j] = inside ? B[i + ro][j + co] : 0;
            }
    }
    /* S1..S10: sums and differences of quadrants (no multiplications). */
    ad(h, b[2], b[4], s[1], 2);  /* S1  = B12 - B22 */
    ad(h, a[1], a[2], s[2], 1);  /* S2  = A11 + A12 */
    ad(h, a[3], a[4], s[3], 1);  /* S3  = A21 + A22 */
    ad(h, b[3], b[1], s[4], 2);  /* S4  = B21 - B11 */
    ad(h, a[1], a[4], s[5], 1);  /* S5  = A11 + A22 */
    ad(h, b[1], b[4], s[6], 1);  /* S6  = B11 + B22 */
    ad(h, a[2], a[4], s[7], 2);  /* S7  = A12 - A22 */
    ad(h, b[3], b[4], s[8], 1);  /* S8  = B21 + B22 */
    ad(h, a[1], a[3], s[9], 2);  /* S9  = A11 - A21 */
    ad(h, b[1], b[2], s[10], 1); /* S10 = B11 + B12 */
    /* P1..P7: the seven recursive products. */
    cal(h, a[1], s[1], p[1]);
    cal(h, s[2], b[4], p[2]);
    cal(h, s[3], b[1], p[3]);
    cal(h, a[4], s[4], p[4]);
    cal(h, s[5], s[6], p[5]);
    cal(h, s[7], s[8], p[6]);
    cal(h, s[9], s[10], p[7]);
    /* C11 = P5 + P4 - P2 + P6 */
    ad(h, p[5], p[4], c[1], 1);
    ad(h, c[1], p[2], c[1], 2);
    ad(h, c[1], p[6], c[1], 1);
    /* C12 = P1 + P2 */
    ad(h, p[1], p[2], c[2], 1);
    /* C21 = P3 + P4 */
    ad(h, p[3], p[4], c[3], 1);
    /* C22 = P5 + P1 - P3 - P7 */
    ad(h, p[5], p[1], c[4], 1);
    ad(h, c[4], p[3], c[4], 2);
    ad(h, c[4], p[7], c[4], 2);
    /* Merge the quadrants back into C, dropping the padding cells. */
    for (q = 1; q <= 4; q++)
    {
        int ro = (q > 2) ? h : 0;
        int co = (q % 2 == 0) ? h : 0;
        for (i = 1; i <= h; i++)
            for (j = 1; j <= h; j++)
                if (i + ro <= n && j + co <= n)
                    C[i + ro][j + co] = c[q][i][j];
    }
    return;
}
/*
 * Program entry point.
 *
 * Input (stdin): the matrix order n, followed by the n*n integers of the
 * first matrix and then the n*n integers of the second matrix.
 * Output (stdout): a blank line, then the n x n product, one row per line.
 *
 * The matrices are stored 1-based in zero-initialised N x N arrays, so
 * rounding the order up to the next power of two pads them with zeros
 * automatically.  The padded order must still fit in indices 1..N-1;
 * otherwise the program reports the problem on stderr and exits with
 * status 1 instead of writing past the arrays.
 *
 * Returns 0 on success, 1 on invalid or unreadable input.
 */
int main(void)
{
    int m, n, i, j;
    int a[N][N] = { 0 }, b[N][N] = { 0 }, c[N][N] = { 0 };

    /* Read and validate the order before it is used as an array bound. */
    if (scanf("%d", &n) != 1 || n < 1 || n > N - 1)
    {
        fprintf(stderr, "matrix order must be an integer between 1 and %d\n", N - 1);
        return 1;
    }
    m = n; /* remember the real order for input and output */

    /* Round n up to a power of two; n >= 1 here, so the loop terminates. */
    while ((n & (n - 1)) != 0)
        n++;
    if (n > N - 1)
    {
        fprintf(stderr, "matrix order %d pads to %d, which exceeds the supported maximum %d\n",
                m, n, N - 1);
        return 1;
    }

    /* Read both operands. */
    for (i = 1; i <= m; i++)
        for (j = 1; j <= m; j++)
            if (scanf("%d", &a[i][j]) != 1)
            {
                fprintf(stderr, "failed to read element (%d, %d) of the first matrix\n", i, j);
                return 1;
            }
    for (i = 1; i <= m; i++)
        for (j = 1; j <= m; j++)
            if (scanf("%d", &b[i][j]) != 1)
            {
                fprintf(stderr, "failed to read element (%d, %d) of the second matrix\n", i, j);
                return 1;
            }

    /* Multiply the zero-padded matrices. */
    cal(n, a, b, c);

    /* Print only the original m x m part of the result. */
    printf("\n");
    for (i = 1; i <= m; i++)
    {
        for (j = 1; j <= m; j++)
            printf("%4d ", c[i][j]);
        printf("\n");
    }

    /* Keep the console window open until the user presses Enter. */
    getchar();
    getchar();
    return 0;
}
PS:当矩阵规模过大时,可能出现栈溢出。
4.2-3
该思考题提出若 $n$ 不是2的幂时该如何处理。
解决方法很容易:若不是2的幂,则用0扩充矩阵,直至其规模达到2的幂。
PS:若 $num>0$ 且 $(num \mathbin{\&} (num-1)) = 0$,则 $num$ 是2的幂。
4.2-7
题目要求仅用3次实数乘法完成复数 $a+bi$ 和 $c+di$ 相乘(即得到 $ac-bd$ 和 $ad+bc$)。
仿照Strassen方法:
令:
$$
\begin{aligned}
S_1 &= (a+b)\cdot c = ac+bc \\
S_2 &= (c+d)\cdot b = bc+bd \\
S_3 &= (b-a)\cdot d = bd-ad
\end{aligned}
$$
则:
$$
\begin{aligned}
ac-bd &= S_1-S_2 \\
ad+bc &= S_2-S_3
\end{aligned}
$$