- Strassen 矩阵乘法：目前只完成了计算，尚未将最终结果从返回的（补零扩展后的）矩阵中提取出来。
#include<stdio.h>
#include<stdlib.h>
#include<sys/time.h>
#include<unistd.h>
/* Flag type for the "release this operand" parameters (0 = keep, non-zero = free). */
#define bool int

/* Input dimensions: X is n x m, Y is m x k. */
int n, m, k;
/* NOTE(review): not referenced in the visible code. */
int Max_dimention;

/* Input file (read in main) and output file (written by Print). */
FILE *fp_in;
FILE *fp_out;

/* Dense int matrix stored as an array of row pointers: A[i][j] is row i, column j. */
struct matrix {
    int **A;
    int col;
    int row;
};

/* Global operands X and Y; Z is not referenced in the visible code. */
struct matrix X, Y, Z;
/* Return the larger of a and b (b when they are equal). */
int max(int a,int b){
    if(a > b)
        return a;
    return b;
}
/*
 * Write matrix A to the global fp_out: one line per row, each element
 * printed with "%d " (so every value is followed by a space).
 */
void Print(struct matrix A){
    int r = 0;
    while(r < A.row){
        for(int c = 0; c < A.col; c++)
            fprintf(fp_out, "%d ", A.A[r][c]);
        fprintf(fp_out, "\n");
        r++;
    }
}
/*
 * Release a matrix built as an array of row pointers: each of the A.row row
 * buffers, then the row-pointer array itself. Only the heap memory is freed;
 * the caller's struct copy is left unchanged.
 */
void freeMatrix(struct matrix A)
{
    int r = A.row;
    while(r-- > 0)
        free(A.A[r]);
    free(A.A);
}
/*
 * Allocate the global operands from the global sizes: X becomes n x m and
 * Y becomes m x k. Elements are left uninitialised (main fills them in).
 * An allocation failure is reported with perror and terminates the program,
 * instead of leading to a NULL dereference later.
 */
void Malloc0(){
    X.row = n;
    Y.row = X.col = m;
    Y.col = k;
    X.A = malloc(sizeof *X.A * n);
    Y.A = malloc(sizeof *Y.A * m);
    /* malloc(0) may legitimately return NULL, so only non-empty requests are checked. */
    if((n > 0 && X.A == NULL) || (m > 0 && Y.A == NULL)){
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    for(int i=0;i<n;i++){
        X.A[i] = malloc(sizeof *X.A[i] * m);
        if(m > 0 && X.A[i] == NULL){
            perror("malloc");
            exit(EXIT_FAILURE);
        }
    }
    for(int i=0;i<m;i++){
        Y.A[i] = malloc(sizeof *Y.A[i] * k);
        if(k > 0 && Y.A[i] == NULL){
            perror("malloc");
            exit(EXIT_FAILURE);
        }
    }
}
/*
 * Allocate a dim x dim matrix (row and col are both set to dim). Elements
 * are left uninitialised. An allocation failure is reported with perror and
 * terminates the program, instead of leading to a NULL dereference later.
 * The caller releases the result with freeMatrix.
 */
struct matrix Malloc(int dim){
    struct matrix C;
    C.col = C.row = dim;
    C.A = malloc(sizeof *C.A * dim);
    /* malloc(0) may legitimately return NULL, so only non-empty requests are checked. */
    if(dim > 0 && C.A == NULL){
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    for(int i=0;i<dim;i++){
        C.A[i] = malloc(sizeof *C.A[i] * dim);
        if(C.A[i] == NULL){
            perror("malloc");
            exit(EXIT_FAILURE);
        }
    }
    return C;
}
struct matrix splitMatrix(int dim,int x,int y,struct matrix A){
struct matrix C = Malloc(dim);
for(int i=0;i<dim;i++)
for(int j=0;j<dim;j++)
C.A[i][j] = 0;
for(int i=x;i<x+dim&&i<A.row;i++)
for(int j=y;j<y+dim&&j<A.col;j++)
C.A[i-x][j-y] = A.A[i][j];
return C;
}
/*
 * Element-wise sum A + B, returned in a newly allocated matrix of size
 * A.col x A.col (the operands are assumed to be square and of equal size).
 * When Free is non-zero, A (but never B) is released afterwards, so nested
 * calls can consume their temporary left operand.
 */
struct matrix Add(struct matrix A,struct matrix B,bool Free){
    struct matrix sum = Malloc(A.col);
    int r = 0;
    while(r < A.row){
        int c = 0;
        while(c < A.col){
            sum.A[r][c] = A.A[r][c] + B.A[r][c];
            c++;
        }
        r++;
    }
    if(Free)
        freeMatrix(A);
    return sum;
}
/*
 * Element-wise difference A - B, returned in a newly allocated matrix of size
 * A.col x A.col (the operands are assumed to be square and of equal size).
 * When Free is non-zero, A (but never B) is released afterwards, so nested
 * calls can consume their temporary left operand.
 */
struct matrix Sub(struct matrix A,struct matrix B,bool Free){
    struct matrix diff = Malloc(A.col);
    int r = 0;
    while(r < A.row){
        int c = 0;
        while(c < A.col){
            diff.A[r][c] = A.A[r][c] - B.A[r][c];
            c++;
        }
        r++;
    }
    if(Free)
        freeMatrix(A);
    return diff;
}
struct matrix mergeMatrix(struct matrix C11,struct matrix C12,struct matrix C21,struct matrix C22){
struct matrix C = Malloc(C11.col*2);
for(int i=0;i<C11.row;i++)
{
for(int j=0;j<C11.col;j++){
C.A[i][j] = C11.A[i][j];
C.A[i+C11.row][j] = C21.A[i][j];
C.A[i][j+C11.col] = C12.A[i][j];
C.A[i+C11.row][j+C11.col] = C22.A[i][j];
}
}
return C;
}
/*
 * Multiply A (A.row x A.col) by B (B.row x B.col) with Strassen's algorithm.
 *
 * Let orig_dim = max(A.row, A.col, B.row, B.col). Both operands are treated
 * as orig_dim x orig_dim matrices, zero-padded by splitMatrix. When orig_dim
 * is odd (and > 1), one more zero row/column is added so that they split
 * into four equal new_dim x new_dim quadrants.
 *
 * Returns a newly allocated orig_dim x orig_dim matrix owned by the caller.
 * When A.col == B.row, the product A*B occupies its top-left A.row x B.col
 * corner and every other element is 0.
 *
 * The extra padding row/column is trimmed before returning, so each recursive
 * result is exactly new_dim x new_dim. Without the trim, an odd new_dim would
 * yield (new_dim+1)-sized sub-results, and mergeMatrix would place the
 * quadrants at the wrong offsets.
 *
 * freeA / freeB: when non-zero, A / B are released before returning. This is
 * used to consume temporaries built by Add/Sub.
 */
struct matrix Strassen(struct matrix A,struct matrix B,bool freeA,bool freeB){
    /* Unpadded size of this sub-problem; the returned matrix has exactly this size. */
    int orig_dim = max(max(A.col,A.row),max(B.col,B.row));
    struct matrix C;

    if(orig_dim <= 0){
        /* Empty operands give an empty product. Without this guard, new_dim
           would stay 0 and the recursion would never terminate. */
        C = Malloc(0);
    }
    else if(A.row==1 && A.col==1 && B.row==1 && B.col==1){
        /* Base case: 1x1 times 1x1. */
        C = Malloc(1);
        C.A[0][0] = A.A[0][0]*B.A[0][0];
    }
    else{
        int dim = orig_dim;
        if(dim&1 && dim!=1)
            dim++;              /* make the padded size even */
        int new_dim = dim>>1;

        struct matrix A11 = splitMatrix(new_dim,0,0,A),A12 = splitMatrix(new_dim,0,new_dim,A);
        struct matrix A21 = splitMatrix(new_dim,new_dim,0,A),A22 = splitMatrix(new_dim,new_dim,new_dim,A);
        struct matrix B11 = splitMatrix(new_dim,0,0,B),B12 = splitMatrix(new_dim,0,new_dim,B);
        struct matrix B21 = splitMatrix(new_dim,new_dim,0,B),B22 = splitMatrix(new_dim,new_dim,new_dim,B);

        /* The seven Strassen products P1..P7 (index 0 unused). Operands built
           by Add/Sub are temporaries, so the recursive call frees them. */
        struct matrix P[8];
        P[1] = Strassen(Add(A11,A22,0),Add(B11,B22,0),1,1);
        P[2] = Strassen(Add(A21,A22,0),B11,1,0);
        P[3] = Strassen(A11,Sub(B12,B22,0),0,1);
        P[4] = Strassen(A22,Sub(B21,B11,0),0,1);
        P[5] = Strassen(Add(A11,A12,0),B22,1,0);
        P[6] = Strassen(Sub(A21,A11,0),Add(B11,B12,0),1,1);
        P[7] = Strassen(Sub(A12,A22,0),Add(B21,B22,0),1,1);

        /* C11 = P1+P4-P5+P7, C12 = P3+P5, C21 = P2+P4, C22 = P1-P2+P3+P6.
           Intermediate sums are freed by passing Free=1 to the next Add/Sub. */
        struct matrix C11 = Add(Sub(Add(P[1],P[4],0),P[5],1),P[7],1);
        struct matrix C12 = Add(P[3],P[5],0);
        struct matrix C21 = Add(P[2],P[4],0);
        struct matrix C22 = Add(Add(Sub(P[1],P[2],0),P[3],1),P[6],1);
        struct matrix merged = mergeMatrix(C11,C12,C21,C22);

        for(int i=1;i<=7;i++)
            freeMatrix(P[i]);
        freeMatrix(A11);freeMatrix(A12);freeMatrix(A21);freeMatrix(A22);
        freeMatrix(B11);freeMatrix(B12);freeMatrix(B21);freeMatrix(B22);
        freeMatrix(C11);freeMatrix(C12);freeMatrix(C21);freeMatrix(C22);

        /* merged is 2*new_dim square; drop the padding row/column added above
           so that the result has the unpadded size orig_dim. */
        if(merged.row == orig_dim){
            C = merged;
        }else{
            C = splitMatrix(orig_dim,0,0,merged);
            freeMatrix(merged);
        }
    }

    if(freeA)
        freeMatrix(A);
    if(freeB)
        freeMatrix(B);
    return C;
}
int main(int argc, char* argv[])
{
struct timeval start,end;
fp_in = fopen(argv[1],"r");
fp_out = fopen(argv[2],"w");
fscanf(fp_in,"%d%d%d",&n,&m,&k);
Malloc0();
for(int i=0;i<n;i++)
for(int j=0;j<m;j++)
fscanf(fp_in,"%d",&X.A[i][j]);
for(int i=0;i<m;i++)
for(int j=0;j<k;j++)
fscanf(fp_in,"%d",&Y.A[i][j]);
gettimeofday(&start,NULL);
struct matrix ans = Strassen(X,Y,0,0);
gettimeofday(&end,NULL);
long long diff = 1000000 * (end.tv_sec-start.tv_sec)+ end.tv_usec-start.tv_usec;
printf("%lld",diff);
Print(ans);
freeMatrix(ans);
freeMatrix(X);
freeMatrix(Y);
return 0;
}