问题
求两个矩阵相乘的结果
思路
传统矩阵相乘
Strassen矩阵乘法
我们可以不断把规模大的矩阵缩小为一半的矩阵,由此减小了规模
通过减小规模,我们成功把时间复杂度为O(n^3) 降到了O(n^log7) 即O(n^2.81)
代码实现
这里我们只考虑矩阵规模为2的倍数的矩阵乘法
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#define N 4//N为2的倍数
void add(int n,int A[][N],int B[][N],int C[][N]){
int i,j;
for(i=0;i<n;i++){
for(j=0;j<n;j++){
C[i][j]=A[i][j]+B[i][j];
}
}
}
void sub(int n,int A[][N],int B[][N],int C[][N]){
int i,j;
for(i=0;i<n;i++){
for(j=0;j<n;j++){
C[i][j]=A[i][j]-B[i][j];
}
}
}
void Strassen(int n,int A[][N],int B[][N],int C[][N]){
int A11[N][N],A12[N][N],A21[N][N],A22[N][N];
int B11[N][N],B12[N][N],B21[N][N],B22[N][N];
int C11[N][N],C12[N][N],C21[N][N],C22[N][N];
int M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],M6[N][N],M7[N][N];
int T1[N][N],T2[N][N];
int i,j;
if(n==2){
for(i=0;i<2;i++){
for(j=0;j<2;j++){
C[i][j]=0;
for(int t=0;t<2;t++){
C[i][j]+=A[i][t]*B[t][j];
}
}
}
return ;
}
for(i=0;i<n/2;i++){//划分A,B
for(j=0;j<n/2;j++){
A11[i][j]=A[i][j];
A12[i][j]=A[i][j+n/2];
A21[i][j]=A[i+n/2][j];
A22[i][j]=A[i+n/2][j+n/2];
B11[i][j]=B[i][j];
B12[i][j]=B[i][j+n/2];
B21[i][j]=B[i+n/2][j];
B22[i][j]=B[i+n/2][j+n/2];
}
}
//M1=A11(B12-B22)
sub(n/2,B12,B22,T1);
Strassen(n/2,A11,T1,M1);
//M2=(A11+A12)B22
add(n/2,A11,A12,T1);
Strassen(n/2,T1,B22,M2);
//M3=(A21+A22)B11
add(n/2,A21,A22,T1);
Strassen(n/2,T1,B11,M3);
//M4=A22(B21-B11)
sub(n/2,B21,B11,T1);
Strassen(n/2,A22,T1,M4);
//M5=(A11+A22)(B11+B22)
add(n/2,A11,A22,T1);
add(n/2,B11,B22,T2);
Strassen(n/2,T1,T2,M5);
//M6=(A12-A22)(B21+B22)
sub(n/2,A12,A22,T1);
add(n/2,B21,B22,T2);
Strassen(n/2,T1,T2,M6);
//M7=(A11-A21)(B11+B12)
sub(n/2,A11,A21,T1);
add(n/2,B11,B12,T2);
Strassen(n/2,T1,T2,M7);
//C11=M5+M4-M2+M6
add(n/2,M5,M4,T1);
sub(n/2,T1,M2,T2);
add(n/2,T2,M6,C11);
//C12=M1+M2
add(n/2,M1,M2,C12);
//C21=M3+M4
add(n/2,M3,M4,C21);
//C22=M5+M1-M3-M7
add(n/2,M5,M1,T1);
sub(n/2,T1,M3,T2);
sub(n/2,T2,M7,C22);
for(i=0;i<n/2;i++){//合并到C
for(j=0;j<n/2;j++){
C[i][j]=C11[i][j];
C[i+n/2][j]=C21[i][j];
C[i][j+n/2]=C12[i][j];
C[i+n/2][j+n/2]=C22[i][j];
}
}
}
int main(){
int n,i,j;
int A[N][N],B[N][N],C[N][N];
for(i=0;i<N;i++){
for(j=0;j<N;j++){
scanf("%d",&A[i][j]);
}
}
for(i=0;i<N;i++){
for(j=0;j<N;j++){
scanf("%d",&B[i][j]);
}
}
Strassen(N,A,B,C);
for(i=0;i<N;i++){
for(j=0;j<N;j++){
printf("%d\t",C[i][j]);
}
printf("\n");
}
return 0;
}