【算法导论】4.2

矩阵乘法的 Strassen 算法


朴素算法时间复杂度: Θ ( n 3 ) Θ(n^3) Θ(n3)


一般分治算法:
(1) A = [ A 11 A 12 A 21 A 22 ]    B = [ B 11 B 12 B 21 B 22 ]    C = [ C 11 C 12 C 21 C 22 ] A=\left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix} \right] ~~B=\left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix} \right] ~~ C=\left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix} \right] \tag{1} A=[A11A21A12A22]  B=[B11B21B12B22]  C=[C11C21C12C22](1)
其中四个子矩阵的规模为 n / 2 n/2 n/2 则:
[ C 11 C 12 C 21 C 22 ] = [ A 11 A 12 A 21 A 22 ] . [ B 11 B 12 B 21 B 22 ] \left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix} \right] =\left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix} \right] .\left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix} \right] [C11C21C12C22]=[A11A21A12A22].[B11B21B12B22]
如此递归求解,则:
T ( n ) = { Θ ( 1 )                                  n = 1 8 T ( n / 2 ) + Θ ( n 2 )             n > 1 T(n)=\left\{ \begin{matrix} Θ(1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~n=1\\ 8T(n/2)+Θ(n^2)~~~~~~~~~~~n>1 \end{matrix} \right. T(n)={Θ(1)                                n=18T(n/2)+Θ(n2)           n>1
解得 T ( n ) = Θ ( n 3 ) T(n)=Θ(n^3) T(n)=Θ(n3)


Strassen算法:

  1. 仍按 ( 1 ) (1) (1) 式将矩阵分解。
  2. 按一定公式计算 S 1 , S 2 . . . S 10 S_1,S_2...S_{10} S1,S2...S10(仅包含加减运算)。
  3. 按一定公式递归的计算7个矩阵积 P 1 , P 2 . . . P 7 P_1,P_2...P_7 P1,P2...P7;每个矩阵规模都是 n / 2 n/2 n/2
  4. 通过 P i P_i Pi矩阵的不同组合进行加减运算,得出 C 11 , C 12 , C 21 , C 22 C_{11},C_{12},C_{21},C_{22} C11,C12,C21,C22
  5. 合并 C 11 , C 12 , C 21 , C 22 C_{11},C_{12},C_{21},C_{22} C11,C12,C21,C22得出 C C C

得到此法递归式:
T ( n ) = { Θ ( 1 )                                  n = 1 7 T ( n / 2 ) + Θ ( n 2 )             n > 1 T(n)=\left\{ \begin{matrix} Θ(1) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~n=1\\ 7T(n/2)+Θ(n^2)~~~~~~~~~~~n>1 \end{matrix} \right. T(n)={Θ(1)                                n=17T(n/2)+Θ(n2)           n>1
解得 T ( n ) = Θ ( n l g 7 ) T(n)=Θ(n^{lg_7}) T(n)=Θ(nlg7)

代码如下:

#define _CRT_SECURE_NO_WARNINGS
#include<stdio.h>
#define N 20
/*矩阵加法,f==1 表示加,f==2 表示减*/
void ad(int n, int a[N][N], int b[N][N], int c[N][N], int f)
{
	int i, j;
	for (i = 1; i <= n; i++)
		for (j = 1; j <= n; j++)
			if (f == 1)
				c[i][j] = a[i][j] + b[i][j];
			else
				c[i][j] = a[i][j] - b[i][j];
	return;
}
/*递归函数*/
void cal(int n, int A[N][N], int B[N][N], int C[N][N])
{
	/*递归出口*/
	if (n == 1) 
	{
		C[1][1] = A[1][1] * B[1][1];
		return;
	}
	int a[6][N][N], b[6][N][N], c[6][N][N], s[12][N][N], p[12][N][N];
	int i, j;
	/*拆分A,B矩阵*/
	for (i = 1; i <= n / 2; i++)
		for (j = 1; j <= n / 2; j++)
		{
			a[1][i][j] = A[i][j];
			b[1][i][j] = B[i][j];
		}
	for (i = 1; i <= n / 2; i++)
		for (j = 1; j <= n / 2; j++)
		{
			a[2][i][j] = A[i][j + n / 2];
			b[2][i][j] = B[i][j + n / 2];
		}
	for (i = 1; i <= n / 2; i++)
		for (j = 1; j <= n / 2; j++)
		{
			a[3][i][j] = A[i + n / 2][j];
			b[3][i][j] = B[i + n / 2][j];
		}
	for (i = 1; i <= n / 2; i++)
		for (j = 1; j <= n / 2; j++)
		{
			a[4][i][j] = A[i + n / 2][j + n / 2];
			b[4][i][j] = B[i + n / 2][j + n / 2];
		}
	/*计算s1-s10*/
	ad(n / 2, b[2], b[4], s[1], 2);
	ad(n / 2, a[1], a[2], s[2], 1);
	ad(n / 2, a[3], a[4], s[3], 1);
	ad(n / 2, b[3], b[1], s[4], 2);
	ad(n / 2, a[1], a[4], s[5], 1);
	ad(n / 2, b[1], b[4], s[6], 1);
	ad(n / 2, a[2], a[4], s[7], 2);
	ad(n / 2, b[3], b[4], s[8], 1);
	ad(n / 2, a[1], a[3], s[9], 2);
	ad(n / 2, b[1], b[2], s[10], 1);
	/*7次递归计算*/
	cal(n / 2, a[1], s[1], p[1]);
	cal(n / 2, s[2], b[4], p[2]);
	cal(n / 2, s[3], b[1], p[3]);
	cal(n / 2, a[4], s[4], p[4]);
	cal(n / 2, s[5], s[6], p[5]);
	cal(n / 2, s[7], s[8], p[6]);
	cal(n / 2, s[9], s[10], p[7]);
	/*计算C11*/
	ad(n / 2, p[5], p[4], c[1], 1);
	ad(n / 2, c[1], p[2], c[1], 2);
	ad(n / 2, c[1], p[6], c[1], 1);
	/*计算C12*/
	ad(n / 2, p[1], p[2], c[2], 1);
	/*计算C21*/
	ad(n / 2, p[3], p[4], c[3], 1);
	/*计算C22*/
	ad(n / 2, p[5], p[1], c[4], 1);
	ad(n / 2, c[4], p[3], c[4], 2);
	ad(n / 2, c[4], p[7], c[4], 2);
	/*将C11,C12,C21,C22合并成C*/
	for (i = 1; i <= n / 2; i++)
		for (j = 1; j <= n / 2; j++)
			C[i][j] = c[1][i][j];
	for (i = 1; i <= n / 2; i++)
		for (j = n / 2 + 1; j <= n; j++)
			C[i][j] = c[2][i][j - n / 2];
	for (i = n / 2 + 1; i <= n; i++)
		for (j = 1; j <= n / 2; j++)
			C[i][j] = c[3][i - n / 2][j];
	for (i = n / 2 + 1; i <= n; i++)
		for (j = n / 2 + 1; j <= n; j++)
			C[i][j] = c[4][i - n / 2][j - n / 2];
	return;
}
void main()
{
	int m, n, i, j, a[N][N] = { 0 }, b[N][N] = { 0 }, c[N][N] = { 0 };
	/*读入*/
	scanf("%d", &n);
	m = n;
	for (i = 1; i <= n; i++)
		for (j = 1; j <= n; j++)
			scanf("%d", &a[i][j]);
	for (i = 1; i <= n; i++)
		for (j = 1; j <= n; j++)
			scanf("%d", &b[i][j]);
	while ((n & (n - 1)) != 0)
		n++;
	/*计算*/
	cal(n, a, b, c);
	/*输出*/
	printf("\n");
	for (i = 1; i <= m; i++)
	{
		for (j = 1; j <= m; j++)
			printf("%4d ", c[i][j]);
		printf("\n");
	}
	getchar();
	getchar();
}

PS:当矩阵规模过大时,可能出现栈溢出。


4.2-3
该思考题提出若 n n n不是2的幂时该如何处理。
解决方法很容易:若不是2的幂,则用0扩充矩阵,直至其规模达到2的幂。
PS:若 ( n u m &amp; ( n u m − 1 = 0 ) ) (num\&amp;(num-1=0)) (num&(num1=0)) n u m num num是2的幂。


4.2-7
题目要求仅用3次实数乘法完成复数 a + b i 和 c + d i a+bi和c+di a+bic+di相乘(即得到 a c − b d 和 a d + b c ac-bd和ad+bc acbdad+bc)。
仿照Strassen方法:

令:
S 1 = ( a + b ) ∗ c = a c + b c S_1=(a+b)*c=ac+bc S1=(a+b)c=ac+bc
S 2 = ( c + d ) ∗ b = b c + b d S_2=(c+d)*b=bc+bd S2=(c+d)b=bc+bd
S 3 = ( b − a ) ∗ d = b d − a d S_3=(b-a)*d=bd-ad S3=(ba)d=bdad
则:
a c − b d = S 1 − S 2 ac-bd=S_1-S_2 acbd=S1S2
a d + b c = S 2 − S 3 ad+bc=S_2-S_3 ad+bc=S2S3


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值