Strassen算法的核心思想是令递归树不那么茂盛一点,只递归进行7次而不是8次n/2*n/2矩阵的乘法。减少一次矩阵乘法带来的代价。
它包含4个步骤
1.将矩阵 A B C分解成n/2*n/2的子矩阵
2.创建10个n/2*n/2矩阵s1到s10 用来保存步骤1中的子矩阵的差 和 和 。
3.创建7个n/2*n/2矩阵p1到p7,用步骤1的子矩阵和2中的10个矩阵,递归的计算7个的积。
4.对3中的pi矩阵进行加减运算 C11=p5+p4-p2+p6 C12=p1+p2 C21=p3+p4 C22=p5+p1-p3-p7
5.再把C11 C12 C21 C22 赋回给C。
// 矩阵乘法之Strassen算法.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <iostream>
using namespace std;
const int N=4;
void input(int a[][N],int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
cin>>a[i][j];
}
}
return ;
}
void output(int a[][N],int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
cout<<a[i][j]<<" ";
}
cout<<endl;
}
return ;
}
void MATRIX_MULTIPLY(int a[][N],int b[][N],int c[][N])//当N为2时直接计算 按普通方法
{
for(int i=0;i<2;i++)
{
for(int j=0;j<2;j++)
{
c[i][j]=0;
for(int k=0;k<2;k++)
{
c[i][j]=c[i][j]+a[i][k]*b[k][j];
}
}
}
return ;
}
void sum(int a[][N],int b[][N],int c[][N],int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]+b[i][j];
}
}
return ;
}
void sub(int a[][N],int b[][N],int c[][N],int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]-b[i][j];
}
}
return ;
}
void Strassen(int n,int a[][N],int b[][N],int c[][N])
{
int a11[N][N],a12[N][N],a21[N][N],a22[N][N];
int b11[N][N],b12[N][N],b21[N][N],b22[N][N];
int c11[N][N],c12[N][N],c21[N][N],c22[N][N];
int s1[N][N],s2[N][N],s3[N][N],s4[N][N],s5[N][N],s6[N][N],s7[N][N],s8[N][N],s9[N][N],s10[N][N];
int p1[N][N],p2[N][N],p3[N][N],p4[N][N],p5[N][N],p6[N][N],p7[N][N];
int MM1[N][N],MM2[N][N];
if(n==2)MATRIX_MULTIPLY(a,b,c);
else
{
for(int i=0;i<n/2;i++)//第一步把a b c矩阵分解为N*N的子矩阵
{
for(int j=0;j<n/2;j++)
{
a11[i][j]=a[i][j];
a12[i][j]=a[i][j+n/2];
a21[i][j]=a[i+n/2][j];
a22[i][j]=a[i+n/2][j+n/2];
b11[i][j]=b[i][j];
b12[i][j]=b[i][j+n/2];
b21[i][j]=b[i+n/2][j];
b22[i][j]=b[i+n/2][j+n/2];
}
}//a b分解完成
//第二步每个矩阵保存1中创建的两个矩阵的和或差
sub(b12,b22,s1,n/2);
sum(a11,a12,s2,n/2);
sum(a21,a22,s3,n/2);
sub(b21,b11,s4,n/2);
sum(a11,a22,s5,n/2);
sum(b11,b22,s6,n/2);
sub(a12,a22,s7,n/2);
sum(b21,b22,s8,n/2);
sub(a11,a21,s9,n/2);
sum(b11,b12,s10,n/2);
//第三步利用1中建立的子矩阵和2中建立的10个矩阵递归的计算7个矩阵的积每个矩阵pi都是N
Strassen(n/2,a11,s1,p1);
Strassen(n/2,s2,b22,p2);
Strassen(n/2,s3,b11,p3);
Strassen(n/2,a22,s4,p4);
Strassen(n/2,s5,s6,p5);
Strassen(n/2,s7,s8,p6);
Strassen(n/2,s9,s10,p7);
//对3中创建的pi矩阵进行加减法运算,并计算出4个n/2*n/2的子矩阵
sum(p5,p4,MM1,N/2);
sub(p2,p6,MM2,N/2);
sub(MM1,MM2,c11,N/2);//c11=p5+p4-p2+p6
sum(p1,p2,c12,N/2);//c12=p1+p2
sum(p3,p4,c21,N/2);//c21=p3+p4
sum(p5,p1,MM1,N/2);//c21=p3+p4
sum(p3,p7,MM2,N/2);
sub(MM1,MM2,c22,N/2);//c11=p5+P1-P3-P7
for(int i=0;i<n/2;i++)
{
for(int j=0;j<n/2;j++)
{
c[i][j]=c11[i][j];
c[i][j+n/2]=c12[i][j];
c[i+n/2][j]=c21[i][j];
c[i+n/2][j+n/2]=c22[i][j];
}
}
}
return ;
}
int main ()
{
int a[N][N];
int b[N][N];
int c[N][N];
cout<<"输入矩阵A"<<endl;
input(a,N);
cout<<"输入矩阵B"<<endl;
input(b,N);
Strassen(N,a,b,c);
cout<<"相乘之后的矩阵为\n";
output(c,N);
return 1;
}