算法导论 第四章矩阵乘法的Strassen算法

Strassen算法的核心思想是令递归树不那么茂盛一点,只递归进行7次而不是8次n/2*n/2矩阵的乘法。减少一次矩阵乘法带来的代价。

它包含4个步骤

1.将矩阵 A B C分解成n/2*n/2的子矩阵

2.创建10个n/2*n/2矩阵s1到s10 用来保存步骤1中的子矩阵的差 和 和 。

3.创建7个n/2*n/2矩阵p1到p7,用步骤1的子矩阵和2中的10个矩阵,递归的计算7个的积。

4.对3中的pi矩阵进行加减运算 C11=p5+p4-p2+p6   C12=p1+p2  C21=p3+p4  C22=p5+p1-p3-p7

5.再把C11 C12 C21 C22 赋回给C。

// 矩阵乘法之Strassen算法.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <iostream>
using namespace std;
const int N=4;
void input(int a[][N],int n)
{
	for(int i=0;i<n;i++)
	{
		for(int j=0;j<n;j++)
		{
			cin>>a[i][j];
		}
	}
	return ;
}
void output(int a[][N],int n)
{
	for(int i=0;i<n;i++)
	{
		for(int j=0;j<n;j++)
		{
			cout<<a[i][j]<<" ";
		}
		cout<<endl;
	}
	return ;
}
void MATRIX_MULTIPLY(int a[][N],int b[][N],int c[][N])//当N为2时直接计算 按普通方法
{
	for(int i=0;i<2;i++)
	{
		for(int j=0;j<2;j++)
		{
			c[i][j]=0;
			for(int k=0;k<2;k++)
			{
				c[i][j]=c[i][j]+a[i][k]*b[k][j];
			}
		}
	}
	return ;
}
void sum(int a[][N],int b[][N],int c[][N],int n)
{
	for(int i=0;i<n;i++)
	{
		for(int j=0;j<n;j++)
		{
		   c[i][j]=a[i][j]+b[i][j];
		}
	}
	return ;
}
void sub(int a[][N],int b[][N],int c[][N],int n)
{
	for(int i=0;i<n;i++)
	{
		for(int j=0;j<n;j++)
		{
		   c[i][j]=a[i][j]-b[i][j];
		}
	}
	return ;
}
void Strassen(int n,int a[][N],int b[][N],int c[][N])
{

	    int a11[N][N],a12[N][N],a21[N][N],a22[N][N];
		int b11[N][N],b12[N][N],b21[N][N],b22[N][N];
		int c11[N][N],c12[N][N],c21[N][N],c22[N][N];
		int s1[N][N],s2[N][N],s3[N][N],s4[N][N],s5[N][N],s6[N][N],s7[N][N],s8[N][N],s9[N][N],s10[N][N];
		int p1[N][N],p2[N][N],p3[N][N],p4[N][N],p5[N][N],p6[N][N],p7[N][N];
		int MM1[N][N],MM2[N][N];
	if(n==2)MATRIX_MULTIPLY(a,b,c);
	else
	{
	
		for(int i=0;i<n/2;i++)//第一步把a b c矩阵分解为N*N的子矩阵
		{
			for(int j=0;j<n/2;j++)
			{
				a11[i][j]=a[i][j];
				a12[i][j]=a[i][j+n/2];
				a21[i][j]=a[i+n/2][j];
				a22[i][j]=a[i+n/2][j+n/2];

				b11[i][j]=b[i][j];
				b12[i][j]=b[i][j+n/2];
				b21[i][j]=b[i+n/2][j];
				b22[i][j]=b[i+n/2][j+n/2];
			}
		}//a b分解完成   
	//第二步每个矩阵保存1中创建的两个矩阵的和或差
		sub(b12,b22,s1,n/2);
		sum(a11,a12,s2,n/2);
		sum(a21,a22,s3,n/2);
		sub(b21,b11,s4,n/2);
		sum(a11,a22,s5,n/2);
		sum(b11,b22,s6,n/2);
		sub(a12,a22,s7,n/2);
		sum(b21,b22,s8,n/2);
		sub(a11,a21,s9,n/2);
		sum(b11,b12,s10,n/2);
		
	//第三步利用1中建立的子矩阵和2中建立的10个矩阵递归的计算7个矩阵的积每个矩阵pi都是N
      
		Strassen(n/2,a11,s1,p1);
		Strassen(n/2,s2,b22,p2);
		Strassen(n/2,s3,b11,p3);
		Strassen(n/2,a22,s4,p4);
		Strassen(n/2,s5,s6,p5);
		Strassen(n/2,s7,s8,p6);
		Strassen(n/2,s9,s10,p7);
	//对3中创建的pi矩阵进行加减法运算,并计算出4个n/2*n/2的子矩阵
		sum(p5,p4,MM1,N/2);
		sub(p2,p6,MM2,N/2);
		sub(MM1,MM2,c11,N/2);//c11=p5+p4-p2+p6
		sum(p1,p2,c12,N/2);//c12=p1+p2
		sum(p3,p4,c21,N/2);//c21=p3+p4
		sum(p5,p1,MM1,N/2);//c21=p3+p4
		sum(p3,p7,MM2,N/2);
		sub(MM1,MM2,c22,N/2);//c11=p5+P1-P3-P7
		for(int i=0;i<n/2;i++)
		{
			for(int j=0;j<n/2;j++)
			{
				c[i][j]=c11[i][j];
				c[i][j+n/2]=c12[i][j];
				c[i+n/2][j]=c21[i][j];
				c[i+n/2][j+n/2]=c22[i][j];
			}
		}

	}
	return ;
}

int main ()
{
	int a[N][N];
	int b[N][N];
	int c[N][N];
	cout<<"输入矩阵A"<<endl;
	input(a,N);
	cout<<"输入矩阵B"<<endl;
	input(b,N);
	Strassen(N,a,b,c);
	cout<<"相乘之后的矩阵为\n";
	output(c,N);
	return 1;

}


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值