分治算法---矩阵乘法

[实验目的]

掌握分治策略的基本思想以及用分治法解决问题的技巧,运用分治法解决矩阵乘法的复杂度过高的问题。

【问题描述】设A和B是两个n*n阶矩阵,求它们的乘积矩阵C。(假设n=2k)。

【提示】A和B是两个n*n阶矩阵,它们的乘积C=A*B也是一个n*n阶矩阵,用传统方法计算矩阵C中的元素的公式如下:

               

由公式(1)可以知道,每计算一个C[i][j],需要做n次乘法和n-1次加法。因此计算C的n*n个元素需要n3次乘法和n3-n2次加法,因此算法的时间复杂度为O(n3)。现在使用分治法来降低算法的时间复杂度。如果将矩阵A,B,C中每一矩阵都分成4个大小相等的子矩阵,每个子矩阵是(n/2)*(n/2)的方阵,则可以将方阵C=A*B重写如下所示:

 

因此可得:

C11=A11*B11+A12*B21

C12=A11*B12+A12*B22

C21=A21*B11+A22*B21

C22=A21*B12+A22*B22

这样将2个n阶矩阵的乘积改为计算8个n/2阶矩阵的乘积和4个n/2阶矩阵的加法运算。当n=1时,2个1阶方阵的乘积可以直接算出,只需要做1次乘法。当子矩阵n>1时,为求两个子矩阵的乘积,可继续对两个子矩阵进行划分,直到子矩阵的阶为1可以直接计算为止。

但是这个算法并没有降低算法的时间复杂度,下面我们改用Strassen矩阵乘法来求C矩阵,具体的公式如下:

M1=A11*(B12-B22)

M2=(A11+A12)*B22

M3=(A21+A22)*B11

M4=A22*(B21-B11)

M5=(A11+A22)*(B11+B22)

M6=(A12-A22)*(B21+B22)

M7=(A11-A21)*(B11+B12)

则C矩阵的四个子矩阵的计算如下所示:

C11=M5+M4-M2+M6

C12=M1+M2

C21=M3+M4

C22=M5+M1-M3-M7

此算法共进行7次矩阵乘法,算法复杂度得到有效降低。

【主要算法框架】

The algorithm as follow:

void Strassen(n,A,B,C);
{   if (n==2)  

           MatrixMultiply(A,B,C);

     else {  divide A and B depend on formula (1);

                Strassen(n/2,A11,B12-B22,M1);

                Strassen(n/2,A11+A12,B22,M2);

                Strassen(n/2,A21+A22,B11,M3);

                Strassen(n/2,A22,B21-B11,M4);

                Strassen(n/2,A11+A22,B11+B22,M5);

                Strassen(n/2,A12-A22,B21+B22,M6);

                Strassen(n/2,A11-A21,B11+B12,M7);

         }

}

如下介绍一些参考算法 分析一些大佬的算法进行学习

#include <iostream>
#include <cstddef>
#include <cstdlib>
#include <ctime>
 
using namespace std;
 
int *InitMatrix(int row,int col);//初始化 
void FillMatrix(int *MatrixA, int size);//自动填充 
void PrintMatrix(int *MatrixA,int size);//打印矩阵 
void AddMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//加 
void SubMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//减 
void SplitMatrix(int *MatrixIn,int *MatrixOut,int size,int part);//四分
void StitchMatrix(int *PartA,int *PartB,int *PartC,int *PartD,int *Result,int size);//反着拼回去
void Strassen(int *MA,int *MB,int *MC,int size);  //Strassen算法
void GradeSchool(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//对比算法 
 
 
int main()
{
	clock_t StartTimeS,EndTimeS,StartTimeG,EndTimeG;
	int MaSize = 0;
	cout << "Please input the row of matrix(it must be index of two,like,2,4,8):";
	cin >> MaSize;
	int *MA = NULL;//规避野指针 
	int *MB = NULL;
	int *MC = NULL;
	MA = InitMatrix(MaSize,MaSize);
	MB = InitMatrix(MaSize,MaSize);
	MC = InitMatrix(MaSize,MaSize);
	FillMatrix(MA,MaSize);
	FillMatrix(MB,MaSize); 
	cout << "Matrix A is:" << endl << endl;
	PrintMatrix(MA,MaSize);
	cout << "Matrix B is:" << endl << endl;
	PrintMatrix(MB,MaSize);
	cout << "Matrix A and B are generated!" << endl << "Start to caculate!" << endl;//提示填充完毕
	StartTimeS = clock();
	Strassen(MA,MB,MC,MaSize);
	EndTimeS = clock();
	cout << "After Strassen multiplication the result is:" << endl << endl;
	PrintMatrix(MC,MaSize);
	StartTimeG = clock();
	GradeSchool(MA,MB,MC,MaSize);
	EndTimeG = clock();
	cout << "After Strassen multiplication the result is:" << endl << endl;
	PrintMatrix(MC,MaSize);
	cout << "Strassen method starts at:" << StartTimeS << endl << "ends at:" << EndTimeS << endl;
	cout << "Grade-School method starts at:" << StartTimeG << endl << "ends at:" << EndTimeG << endl;
	
	free(MA);//释放空间 
	free(MB);
	free(MC);
	
	return 0;
} 
 
int *InitMatrix(int row,int col)//初始化矩阵,大小事先不确定,所以需要动态分配  
{
	int *p;
	size_t size = sizeof(int)*row*col;//需要开row*col个int类型大小的空间 
	if (NULL == (p = (int *)malloc(size)))  
    {
    	cout << "Error in InitMatrix!" << endl;
    	return NULL;
    }
    else  
		return p;    //返回矩阵首地址 
}
 
 
void FillMatrix( int *MatrixA, int size)
{
	 for(int row = 0; row < size; row ++)
    {
        for(int col = 0; col < size; col ++)
        {
           cin>>MatrixA[row*size + col];
        }
    }
}
 
void PrintMatrix(int *MatrixA,int size)
{
	//cout<<"The Matrix is:"<<endl;
	for(int row = 0; row < size; row ++)
	{
		for(int col = 0; col < size; col ++)
		{
			cout << MatrixA[row*size + col] << "\t";
			if ((col + 1) % ((size)) == 0)
				cout << endl;
		}
	}
	cout << endl;
}
 
void AddMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	for(int i = 0;i < size*size;i ++)
	{
		MatrixOut[i] = MatrixIn1[i] + MatrixIn2[i];
	}
}
 
void SubMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	for(int i = 0;i < size*size;i ++)
	{
		MatrixOut[i] = MatrixIn1[i] - MatrixIn2[i];
	}
}
 
 
void SplitMatrix(int *MatrixIn,int *MatrixOut,int size,int part)
{
	int n = size/2;//编写方便 
	switch(part)
	{
		case 1://四分左上 
		{
			for (int i = 0;i < n;i ++)  
            {  
                for (int j = 0;j < n;j ++)  
                {  
                    MatrixOut[i*n + j] = MatrixIn[i*n + j];  
                }  
            }  
            break;  
		}
		case 2://四分右上 
		{
			for (int i = 0;i < n;i ++)  
            {  
                for (int j = 0;j < n;j ++)  
                {  
                    MatrixOut[i*n + j] = MatrixIn[i*n + j + n];  
                }  
            }  
            break;  
		}
		case 3://四分左下 
		{
			for (int i = 0; i < n; i ++)  
            {  
                for (int j = 0; j < n; j ++)  
                {  
                    MatrixOut[i*n + j] = MatrixIn[(i + n)*n + j];  
                }  
            }  
            break;  
		}
		case 4://四分右下 
		{
			for (int i = 0; i < n; i ++)  
            {  
                for (int j = 0; j< n; j ++)  
                {  
                    MatrixOut[i*n + j] = MatrixIn[(i + n)*n + j + n];  
                }  
            }  
            break;  
		}
		default :  
        	cout<<"Error in SplitMatrix!"; 
	}
}
 
void StitchMatrix(int *PartA,int *PartB,int *PartC,int *PartD,int *Result,int size)//反着拼回去 
{
	for(int i = 0; i < size; i ++)  
    {  
        for(int j = 0; j < size; j ++)  
        {  
            Result[i*size*2 + j] = PartA[i*size + j];  
            Result[i*size*2 + j + size] = PartB[i*size + j];  
            Result[(i + size)*size*2 + j] = PartC[i*size + j];  
            Result[(i + size)*size*2 + j + size] = PartD[i*size + j];  
        }  
    }  
}
/*
/Strassen算法:
/分块,分到2*2
*/
void Strassen(int *MA,int *MB,int *MC,int size)
{
	int n = size/2;
	if (2 == size)//这样就不用分了,以及分到最后执行这个不用再递归 
    {  
        int p1,p2,p3,p4,p5,p6,p7;  
        p1 = MA[0]*(MB[1]-MB[3]) ;  
        p2 = (MA[0] + MA[1])*MB[3] ;  
        p3 = (MA[2] + MA[3])*MB[0] ;  
        p4 = MA[3]*(MB[2] - MB[0]) ;  
        p5 = (MA[0] + MA[3])*(MB[0] + MB[3]) ;  
        p6 = (MA[1] - MA[3])*(MB[2] + MB[3]) ;  
        p7 = (MA[0] - MA[2])*(MB[0] + MB[1]) ;  
        MC[0] = p5 + p4 - p2 + p6 ;  
        MC[1] = p1 + p2 ;  
        MC[2] = p3 + p4 ;  
        MC[3] = p5 + p1 -p3 - p7 ;  
        return ;      
    }  
    else
	{
		int *MA1 = NULL,*MA2 = NULL,*MA3 = NULL,*MA4 = NULL;
		int *MB1 = NULL,*MB2 = NULL,*MB3 = NULL,*MB4 = NULL;
		int *MC1 = NULL,*MC2 = NULL,*MC3 = NULL,*MC4 = NULL;
		int *p1 = NULL,*p2 = NULL,*p3 = NULL,*p4 = NULL,*p5 = NULL,*p6 = NULL,*p7 = NULL;
		int *TEMP1 = NULL,*TEMP2 = NULL;
		
		
		MA1 = InitMatrix(n,n);
		MA2 = InitMatrix(n,n);
		MA3 = InitMatrix(n,n);
		MA4 = InitMatrix(n,n);
		MB1 = InitMatrix(n,n);
		MB2 = InitMatrix(n,n);
		MB3 = InitMatrix(n,n);
		MB4 = InitMatrix(n,n);
		MC1 = InitMatrix(n,n);
		MC2 = InitMatrix(n,n);
		MC3 = InitMatrix(n,n);
		MC4 = InitMatrix(n,n);
		p1 = InitMatrix(n,n);
		p2 = InitMatrix(n,n);
		p3 = InitMatrix(n,n);
		p4 = InitMatrix(n,n);
		p5 = InitMatrix(n,n);
		p6 = InitMatrix(n,n);
		p7 = InitMatrix(n,n);
		TEMP1 = InitMatrix(n,n);
		TEMP2 = InitMatrix(n,n);
		
		SplitMatrix(MA,MA1,size,1);SplitMatrix(MA,MA2,size,2);SplitMatrix(MA,MA3,size,3);SplitMatrix(MA,MA4,size,4);
		SplitMatrix(MB,MB1,size,1);SplitMatrix(MB,MB2,size,2);SplitMatrix(MB,MB3,size,3);SplitMatrix(MB,MB4,size,4);
 
		/*///
		/* p1=a(f-h)
		/* p2=h(a+b)
		/* p3=e(c+d)
		/* p4=d(g+e)
		/* p5=(e+h)(a+d)
		/* p6=(g+h)(b-d)
		/* p7=(a-c)(e+f)
		/*A a1  b2     B  e1  f2
		/*  c3  d4        g3  h4
		///*/
		
		//p1
		SubMatrix(MB2,MB4,TEMP1,n);
		Strassen(MA1,TEMP1,p1,n);
		//p2
		AddMatrix(MA1,MA2,TEMP1,n);
		Strassen(MB4,TEMP1,p2,n);
		//P3
		AddMatrix(MA3,MA4,TEMP1,n);
		Strassen(MB1,TEMP1,p3,n);	
		//P4
		AddMatrix(MB3,MB1,TEMP1,n);
		Strassen(MA4,TEMP1,p4,n);
		//P5
		AddMatrix(MB1,MB4,TEMP1,n);
		AddMatrix(MA1,MA4,TEMP2,n);
		Strassen(TEMP1,TEMP2,p5,n);
		//P6
		AddMatrix(MB3,MB4,TEMP1,n);
		SubMatrix(MA2,MA4,TEMP1,n);
		Strassen(TEMP1,TEMP2,p6,n);
		//P7
		AddMatrix(MB1,MB2,TEMP1,n);
		SubMatrix(MA1,MA3,TEMP2,n);
		Strassen(TEMP1,TEMP2,p7,n);
		
		//C1=P5+P4+P6-P2
		AddMatrix(p5,p4,TEMP1,n);
		AddMatrix(TEMP1,p6,TEMP2,n);
		SubMatrix(TEMP2,p2,MC1,n);
		
		//C2=P1+P2
		AddMatrix(p1,p2,MC2,n);
		
		//C3=P3+P4
		AddMatrix(p3,p4,MC3,n);
		
		//C4=P5+P1-P3-P7
		AddMatrix(p5,p1,TEMP1,n);
		SubMatrix(TEMP1,p3,TEMP2,n);
		SubMatrix(TEMP2,p7,MC4,n);
		
		StitchMatrix(MC1,MC2,MC3,MC4,MC,n);
		
		free(MA1);free(MA2);free(MA3);free(MA4);
		free(MB1);free(MB2);free(MB3);free(MB4);
		free(MC1);free(MC2);free(MC3);free(MC4);
		free(p1);free(p2);free(p3);free(p4);free(p5);free(p6);free(p7);
		free(TEMP1);free(TEMP2);
		
		return ;
	} 
}
 
void GradeSchool(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	for (int i = 0; i < size; i ++)
    {
        for (int j = 0; j < size; j ++)
        {
			MatrixOut[i*size + j] = 0;
            for (int k = 0; k < size; k ++)
            {
                MatrixOut[i*size + j] = MatrixOut[i*size + j] + MatrixIn1[i*size + k]*MatrixIn2[k*size + j];
            }
        }
    }
}

如下 还有一个大佬的代码挺好 但是不知道为什么运行中总是出现错误 这里后续会进行分析改进

/*
Strassen矩阵乘法:
A和B的乘积矩阵C中的元素C[i,j]定义为:C[i][j] = k从1到n 累加A[i][k]*B[k][j]
若依此定义来计算A和B的乘积矩阵C,则每计
算C的一个元素C[i][j],需要做n次乘法和n-1次
加法。因此,算出矩阵C的 个元素所需的计算
时间为O(n3)
分治法:
将矩阵A,B和C中的每一矩阵都分成4个大小相等的子矩阵。可以将方程C=AB重写为:
[C11 C12]  =  [A11 A12]  [B11 B12]
[C21 C22]     [A21 A22]  [B21 B22]
可以得到
C11 = A11*B11 + A12*B21
C12 = A11*B12 + A12*B22
C21 = A21*B11 + A22*B21
C22 = A21*B12 + A22*B22
为了降低时间复杂度,必须减少乘法的次数:
从8次乘法降为7次,用了7次对于n/2矩阵乘的递归调用和18次n.2矩阵的加减法计算
M1 = A11(B12 - B22)
M2 = (A11+A12)*B22
以下代码是转载的,有空的时候再研究
*/
#include <iostream>
 
using namespace std;
 
const int N=4; //常量N用来定义矩阵的大小
 
void main()
{
    void STRASSEN(int n,float A[][N],float B[][N],float C[][N]);
    void input(int n,float p[][N]);
    void output(int n,float C[][N]);                    //函数声明部分
 
    float A[N][N],B[N][N],C[N][N];  //定义三个矩阵A,B,C
 
    cout<<"现在录入矩阵A[N][N]:"<<endl<<endl;
    input(N,A);
    cout<<endl<<"现在录入矩阵B[N][N]:"<<endl<<endl;
    input(N,B);                         //录入数组
 
    STRASSEN(N,A,B,C);   //调用STRASSEN函数计算
 
    output(N,C);  //输出计算结果
	getchar();
}
 
 
void input(int n,float p[][N])  //矩阵输入函数
{
    int i,j;
 
    for(i=0;i<n;i++)
    {
        cout<<"请输入第"<<i+1<<"行"<<endl;
        for(j=0;j<n;j++)
            cin>>p[i][j];
    }
}
 
void output(int n,float C[][N]) //据矩阵输出函数
{
    int i,j;
    cout<<"输出矩阵:"<<endl;
    for(i=0;i<n;i++)
    {
        cout<<endl;
        for(j=0;j<n;j++)
            cout<<C[i][j]<<"  ";
    }
    cout<<endl<<endl;
}
 
void MATRIX_MULTIPLY(float A[][N],float B[][N],float C[][N])  //按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
{
    int i,j,t;
    for(i=0;i<2;i++)                     //计算A*B-->C
        for(j=0;j<2;j++)
        {    
            C[i][j]=0;                   //计算完一个C[i][j],C[i][j]应重新赋值为零
            for(t=0;t<2;t++)
            C[i][j]=C[i][j]+A[i][t]*B[t][j];
        }
}
 
void MATRIX_ADD(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵加法函数X+Y—>Z
{
    int i,j;
    for(i=0;i<n;i++)
        for(j=0;j<n;j++)
            Z[i][j]=X[i][j]+Y[i][j];
}
 
void MATRIX_SUB(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵减法函数X-Y—>Z
{
    int i,j;
    for(i=0;i<n;i++)
        for(j=0;j<n;j++)
            Z[i][j]=X[i][j]-Y[i][j];
 
}
 
 
void STRASSEN(int n,float A[][N],float B[][N],float C[][N])  //STRASSEN函数(递归)
{
    float A11[N][N],A12[N][N],A21[N][N],A22[N][N];
    float B11[N][N],B12[N][N],B21[N][N],B22[N][N];
    float C11[N][N],C12[N][N],C21[N][N],C22[N][N];
    float M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],M6[N][N],M7[N][N];
    float AA[N][N],BB[N][N],MM1[N][N],MM2[N][N];
 
    int i,j;//,x;
 
 
    if (n==2)
        MATRIX_MULTIPLY(A,B,C);//按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
    else
    {
        for(i=0;i<n/2;i++)              
            for(j=0;j<n/2;j++)
 
                {
                    A11[i][j]=A[i][j];
                    A12[i][j]=A[i][j+n/2];
                    A21[i][j]=A[i+n/2][j];
                    A22[i][j]=A[i+n/2][j+n/2];
                    B11[i][j]=B[i][j];
                    B12[i][j]=B[i][j+n/2];
                    B21[i][j]=B[i+n/2][j];
                    B22[i][j]=B[i+n/2][j+n/2];
                }       //将矩阵A和B式分为四块
 
 
 
 
    MATRIX_SUB(n/2,B12,B22,BB);          
    STRASSEN(n/2,A11,BB,M1);//M1=A11(B12-B22)
 
    MATRIX_ADD(n/2,A11,A12,AA);
    STRASSEN(n/2,AA,B22,M2);//M2=(A11+A12)B22
 
    MATRIX_ADD(n/2,A21,A22,AA);
    STRASSEN(n/2,AA,B11,M3);//M3=(A21+A22)B11
 
    MATRIX_SUB(n/2,B21,B11,BB);
    STRASSEN(n/2,A22,BB,M4);//M4=A22(B21-B11)
 
    MATRIX_ADD(n/2,A11,A22,AA);
    MATRIX_ADD(n/2,B11,B22,BB);
    STRASSEN(n/2,AA,BB,M5);//M5=(A11+A22)(B11+B22)
 
    MATRIX_SUB(n/2,A12,A22,AA);
    MATRIX_SUB(n/2,B21,B22,BB);
    STRASSEN(n/2,AA,BB,M6);//M6=(A12-A22)(B21+B22)
 
    MATRIX_SUB(n/2,A11,A21,AA);
    MATRIX_SUB(n/2,B11,B12,BB);
    STRASSEN(n/2,AA,BB,M7);//M7=(A11-A21)(B11+B12)
    //计算M1,M2,M3,M4,M5,M6,M7(递归部分)
 
 
    MATRIX_ADD(N/2,M5,M4,MM1);                
    MATRIX_SUB(N/2,M2,M6,MM2);
    MATRIX_SUB(N/2,MM1,MM2,C11);//C11=M5+M4-M2+M6
 
    MATRIX_ADD(N/2,M1,M2,C12);//C12=M1+M2
 
    MATRIX_ADD(N/2,M3,M4,C21);//C21=M3+M4
 
    MATRIX_ADD(N/2,M5,M1,MM1);
    MATRIX_ADD(N/2,M3,M7,MM2);
    MATRIX_SUB(N/2,MM1,MM2,C22);//C22=M5+M1-M3-M7
 
    for(i=0;i<n/2;i++)
        for(j=0;j<n/2;j++)
        {
            C[i][j]=C11[i][j];
            C[i][j+n/2]=C12[i][j];
            C[i+n/2][j]=C21[i][j];
            C[i+n/2][j+n/2]=C22[i][j];
        }                                            //计算结果送回C[N][N]
 
    }
 
}

  • 9
    点赞
  • 70
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Simon_Smith

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值