首先介绍一下矩阵A,B相乘的几种方法:
1.暴力计算法:利用三重循环进行计算,时间复杂度为O(n^3)。
/*
 * Naive matrix product: c = a * b for n-by-n matrices stored as arrays of
 * row pointers. Uses the classic triple loop, O(n^3) multiplications.
 *
 * a, b : input matrices (read only)
 * c    : output matrix, every element is overwritten
 * n    : dimension of all three matrices
 */
void NormalMul(int** a,int** b,int **c,int n)
{
    int row, col, inner;

    for (row = 0; row < n; ++row) {
        const int *lhs = a[row];
        int *out = c[row];

        for (col = 0; col < n; ++col) {
            /* Dot product of row `row` of a with column `col` of b. */
            int acc = 0;

            for (inner = 0; inner < n; ++inner) {
                acc += lhs[inner] * b[inner][col];
            }
            out[col] = acc;
        }
    }
}
2.分治算法:将矩阵划分为四个子块后递归相乘,需要8次子矩阵乘法,其时间复杂度仍是O(n^3),所以不再详细说明。
3.Strassen算法:其本质上也是使用了分治的思想,但是将子矩阵乘法的次数从8次减少到7次,从而把时间复杂度优化到了O(n^{log_2 7}),在大规模矩阵乘法上有极大优势。
现在重新定义7个新矩阵
M1=(A11+A22)*(B11+B22)
M2=(A21+A22)*B11
M3=A11*(B12-B22)
M4=A22*(B21-B11)
M5=(A11+A12)*B22
M6=(A21-A11)*(B11+B12)
M7=(A12-A22)*(B21+B22)
结果矩阵C可以组合上述矩阵,如下
C11=M1+M4-M5+M7
C12=M3+M5
C21=M2+M4
C22=M1-M2+M3+M6
这时候共用了7次乘法,18次加减法运算。写出递推公式T(n)=7T(n/2)+Θ(n^2),最终结果是O(n^{log_2 7})≈O(n^{2.807})。
/*
 * Element-wise sum of two size-by-size matrices: c = a + b.
 *
 * a, b : operands (read only)
 * c    : destination; may be the same matrix as a or b, since each element
 *        is read before the matching element is written
 * size : dimension of all three matrices
 */
void Add(int **a,int **b,int **c,int size)
{
    int row = 0;

    while (row < size) {
        const int *lhs = a[row];
        const int *rhs = b[row];
        int *out = c[row];
        int col = 0;

        while (col < size) {
            out[col] = lhs[col] + rhs[col];
            ++col;
        }
        ++row;
    }
}
/*
 * Element-wise difference of two size-by-size matrices: c = a - b.
 *
 * a, b : operands (read only)
 * c    : destination; may be the same matrix as a or b, since each element
 *        is read before the matching element is written
 * size : dimension of all three matrices
 */
void Minus(int **a,int **b,int **c,int size)
{
    int row = 0;

    while (row < size) {
        const int *lhs = a[row];
        const int *rhs = b[row];
        int *out = c[row];
        int col = 0;

        while (col < size) {
            out[col] = lhs[col] - rhs[col];
            ++col;
        }
        ++row;
    }
}
/*
 * Largest dimension that is still multiplied with the plain triple loop.
 * This cutoff can be tuned; below it the recursion overhead of Strassen's
 * method outweighs the saved multiplication.
 */
enum { S_MUL_CUTOFF = 512 };

/*
 * Allocate an n-by-n int matrix as an array of n separately allocated rows.
 * Returns NULL (with nothing left allocated) if any allocation fails.
 */
static int **s_mul_alloc(int n)
{
    int i;
    int **m = (int **)malloc(sizeof *m * (size_t)n);

    if (m == NULL) {
        return NULL;
    }
    for (i = 0; i < n; i++) {
        m[i] = (int *)malloc(sizeof **m * (size_t)n);
        if (m[i] == NULL) {
            /* Roll back the rows allocated so far. */
            while (i-- > 0) {
                free(m[i]);
            }
            free(m);
            return NULL;
        }
    }
    return m;
}

/* Release a matrix previously obtained from s_mul_alloc(n). */
static void s_mul_free(int **m, int n)
{
    int i;

    for (i = 0; i < n; i++) {
        free(m[i]);
    }
    free(m);
}

/*
 * Strassen matrix multiplication: c = a * b for size-by-size matrices stored
 * as arrays of row pointers.
 *
 * The matrices are split into four half-size quadrants and the product is
 * assembled from seven recursive half-size products p1..p7 instead of eight.
 *
 * Base case: NormalMul is used when size <= S_MUL_CUTOFF, and also when size
 * is odd, because an odd matrix cannot be split into equal quadrants
 * (size/2 would silently drop the last row and column). It is also used as
 * a fallback when the temporary quadrants cannot be allocated, so the result
 * in c is always complete.
 *
 * a, b : input matrices (read only)
 * c    : output matrix, every element is overwritten; must not alias a or b
 * size : dimension of all three matrices
 */
void S_Mul(int **a,int **b,int **c,int size)
{
    enum { TEMP_COUNT = 21 };
    int **tmp[TEMP_COUNT];
    int **a11,**a12,**a21,**a22;
    int **b11,**b12,**b21,**b22;
    int **c11,**c12,**c21,**c22;
    int **aa,**bb;
    int **p1,**p2,**p3,**p4,**p5,**p6,**p7;
    int half=size/2;
    int i,j,k;

    /*
     * "<=" rather than "==": with an equality test any size that is not
     * exactly 512 * 2^k would halve past the cutoff and recurse forever.
     */
    if(size<=S_MUL_CUTOFF||size%2!=0){
        NormalMul(a,b,c,size);
        return;
    }

    /* Allocate all half-size temporaries; fall back if memory runs out. */
    for(k=0;k<TEMP_COUNT;k++){
        tmp[k]=s_mul_alloc(half);
        if(tmp[k]==NULL){
            while(k-->0){
                s_mul_free(tmp[k],half);
            }
            NormalMul(a,b,c,size);
            return;
        }
    }
    a11=tmp[0];  a12=tmp[1];  a21=tmp[2];  a22=tmp[3];
    b11=tmp[4];  b12=tmp[5];  b21=tmp[6];  b22=tmp[7];
    c11=tmp[8];  c12=tmp[9];  c21=tmp[10]; c22=tmp[11];
    aa=tmp[12];  bb=tmp[13];
    p1=tmp[14];  p2=tmp[15];  p3=tmp[16];  p4=tmp[17];
    p5=tmp[18];  p6=tmp[19];  p7=tmp[20];

    /* Split a and b into their four quadrants. */
    for(i=0;i<half;i++){
        for(j=0;j<half;j++){
            a11[i][j]=a[i][j];
            a12[i][j]=a[i][j+half];
            a21[i][j]=a[i+half][j];
            a22[i][j]=a[i+half][j+half];
            b11[i][j]=b[i][j];
            b12[i][j]=b[i][j+half];
            b21[i][j]=b[i+half][j];
            b22[i][j]=b[i+half][j+half];
        }
    }

    /* Seven recursive products; aa and bb are scratch operands. */
    Minus(b12,b22,bb,half);
    S_Mul(a11,bb,p1,half);          /* p1 = a11 * (b12 - b22)         */
    Add(a11,a12,aa,half);
    S_Mul(aa,b22,p2,half);          /* p2 = (a11 + a12) * b22         */
    Add(a21,a22,aa,half);
    S_Mul(aa,b11,p3,half);          /* p3 = (a21 + a22) * b11         */
    Minus(b21,b11,bb,half);
    S_Mul(a22,bb,p4,half);          /* p4 = a22 * (b21 - b11)         */
    Add(a11,a22,aa,half);
    Add(b11,b22,bb,half);
    S_Mul(aa,bb,p5,half);           /* p5 = (a11 + a22) * (b11 + b22) */
    Minus(a12,a22,aa,half);
    Add(b21,b22,bb,half);
    S_Mul(aa,bb,p6,half);           /* p6 = (a12 - a22) * (b21 + b22) */
    Minus(a11,a21,aa,half);
    Add(b11,b12,bb,half);
    S_Mul(aa,bb,p7,half);           /* p7 = (a11 - a21) * (b11 + b12) */

    /* c11 = p5 + p4 - p2 + p6 */
    Add(p5,p4,c11,half);
    Minus(c11,p2,c11,half);
    Add(c11,p6,c11,half);
    /* c12 = p1 + p2 */
    Add(p1,p2,c12,half);
    /* c21 = p3 + p4 */
    Add(p3,p4,c21,half);
    /* c22 = p5 + p1 - p3 - p7 */
    Add(p5,p1,c22,half);
    Minus(c22,p3,c22,half);
    Minus(c22,p7,c22,half);

    /* Merge the four result quadrants back into c. */
    for(i=0;i<half;i++){
        for(j=0;j<half;j++){
            c[i][j]=c11[i][j];
            c[i][j+half]=c12[i][j];
            c[i+half][j]=c21[i][j];
            c[i+half][j+half]=c22[i][j];
        }
    }

    /* Release every temporary. */
    for(k=0;k<TEMP_COUNT;k++){
        s_mul_free(tmp[k],half);
    }
}
以上是Strassen算法的代码。