// 算法课实验 (Algorithm course lab): Strassen vs. classic matrix multiplication benchmark
#include "iostream"
#include "ctime"
#include "cstdlib"
#include <cstdio>
using namespace std;
// clock() readings taken by main immediately before and after each
// multiplication; their difference is the elapsed tick count it prints.
clock_t time_start;
clock_t time_end;
// Classic triple-loop product of two m x m matrices: result = a * b.
// Returns a newly allocated m x m matrix whose rows are each allocated with
// new[]; the caller owns it.
int **mul(int**a,int**b,int m){
    int **product = new int*[m];
    for(int row=0; row<m; ++row){
        product[row] = new int[m];
        for(int col=0; col<m; ++col){
            // Dot product of row `row` of a with column `col` of b,
            // accumulated in k order 0..m-1.
            int sum = 0;
            for(int k=0; k<m; ++k){
                sum += a[row][k]*b[k][col];
            }
            product[row][col] = sum;
        }
    }
    return product;
}
// Element-wise sum a + b of two m x m matrices.
// Returns a newly allocated m x m matrix whose rows are each allocated with
// new[]; the caller owns it.
int **strassen_add(int **a,int **b,int m){
    int **sum = new int*[m];
    for(int r=0; r<m; ++r){
        int *out = new int[m];
        const int *lhs = a[r];
        const int *rhs = b[r];
        for(int col=0; col<m; ++col){
            out[col] = lhs[col] + rhs[col];
        }
        sum[r] = out;
    }
    return sum;
}
int **strassen_mul(int **a,int **b,int m);

// Element-wise difference a - b of two m x m matrices ("cut" = subtract).
// Returns a newly allocated m x m matrix whose rows are each allocated with
// new[]; the caller owns it. Companion to strassen_add.
int **strassen_cut(int **a,int **b,int m){
    int **c = new int*[m];
    for(int i=0;i<m;i++){
        c[i] = new int[m];
        for(int j=0;j<m;j++){
            c[i][j] = a[i][j]-b[i][j];
        }
    }
    return c;
}

// Allocates an uninitialized n x n matrix as n separately new[]-ed rows.
static int **alloc_matrix(int n){
    int **x = new int*[n];
    for(int i=0;i<n;i++){
        x[i] = new int[n];
    }
    return x;
}

// Releases an n x n matrix whose rows were each allocated with new[].
static void free_matrix(int **x,int n){
    for(int i=0;i<n;i++){
        delete[] x[i];
    }
    delete[] x;
}

// Multiplies two h x h blocks: recursively with strassen_mul when `recurse`
// is set, otherwise with the classic mul.
static int **strassen_block_mul(int **x,int **y,int h,bool recurse){
    return recurse ? strassen_mul(x,y,h) : mul(x,y,h);
}

// Strassen multiplication of two m x m matrices: result = a * b.
// Returns a newly allocated m x m matrix (rows allocated with new[]); the
// caller owns it, exactly like mul.
// Each operand is split into four h x h quadrants (h = m/2). Strassen's seven
// block products are formed recursively while m >= 64, and with mul below
// that. Odd sizes cannot be split evenly and are multiplied by mul directly.
// All intermediate matrices are freed before returning.
int **strassen_mul(int **a,int **b,int m){
    // With odd m, m/2 would silently drop the last row/column (and m == 1
    // would leave the result uninitialized), so multiply classically.
    if(m%2!=0 || m<2){
        return mul(a,b,m);
    }
    const int h = m/2;
    const bool recurse = m>=64;

    // Quadrant copies: 1 = top-left, 2 = top-right, 3 = bottom-left,
    // 4 = bottom-right.
    int **a1 = alloc_matrix(h), **a2 = alloc_matrix(h);
    int **a3 = alloc_matrix(h), **a4 = alloc_matrix(h);
    int **b1 = alloc_matrix(h), **b2 = alloc_matrix(h);
    int **b3 = alloc_matrix(h), **b4 = alloc_matrix(h);
    for(int i=0;i<h;i++){
        for(int j=0;j<h;j++){
            a1[i][j] = a[i][j];
            b1[i][j] = b[i][j];
            a2[i][j] = a[i][j+h];
            b2[i][j] = b[i][j+h];
            a3[i][j] = a[i+h][j];
            b3[i][j] = b[i+h][j];
            a4[i][j] = a[i+h][j+h];
            b4[i][j] = b[i+h][j+h];
        }
    }

    // Strassen's seven products. Every sum/difference operand is a
    // temporary and is freed as soon as its product has been formed.
    int **s = strassen_cut(b2,b4,h);
    int **m1 = strassen_block_mul(a1,s,h,recurse);      // a1 (b2 - b4)
    free_matrix(s,h);

    s = strassen_add(a1,a2,h);
    int **m2 = strassen_block_mul(s,b4,h,recurse);      // (a1 + a2) b4
    free_matrix(s,h);

    s = strassen_add(a3,a4,h);
    int **m3 = strassen_block_mul(s,b1,h,recurse);      // (a3 + a4) b1
    free_matrix(s,h);

    s = strassen_cut(b3,b1,h);
    int **m4 = strassen_block_mul(a4,s,h,recurse);      // a4 (b3 - b1)
    free_matrix(s,h);

    s = strassen_add(a1,a4,h);
    int **t = strassen_add(b1,b4,h);
    int **m5 = strassen_block_mul(s,t,h,recurse);       // (a1 + a4)(b1 + b4)
    free_matrix(s,h);
    free_matrix(t,h);

    s = strassen_cut(a2,a4,h);
    t = strassen_add(b3,b4,h);
    int **m6 = strassen_block_mul(s,t,h,recurse);       // (a2 - a4)(b3 + b4)
    free_matrix(s,h);
    free_matrix(t,h);

    s = strassen_cut(a1,a3,h);
    t = strassen_add(b1,b2,h);
    int **m7 = strassen_block_mul(s,t,h,recurse);       // (a1 - a3)(b1 + b2)
    free_matrix(s,h);
    free_matrix(t,h);

    int **quadrants[] = {a1,a2,a3,a4,b1,b2,b3,b4};
    for(int q=0;q<8;q++){
        free_matrix(quadrants[q],h);
    }

    // Recombine directly into the result; no per-quadrant temporaries needed.
    int **c = alloc_matrix(m);
    for(int i=0;i<h;i++){
        for(int j=0;j<h;j++){
            c[i][j]     = m5[i][j]+m4[i][j]-m2[i][j]+m6[i][j];
            c[i][j+h]   = m1[i][j]+m2[i][j];
            c[i+h][j]   = m3[i][j]+m4[i][j];
            c[i+h][j+h] = m5[i][j]+m1[i][j]-m3[i][j]-m7[i][j];
        }
    }

    int **products[] = {m1,m2,m3,m4,m5,m6,m7};
    for(int p=0;p<7;p++){
        free_matrix(products[p],h);
    }
    return c;
}
// Writes the m x m matrix x to `path` as text: one row per line, each value
// followed by a space. Returns false if the file could not be opened.
static bool write_matrix_file(const char *path,int **x,int m){
    FILE *file = fopen(path,"w");
    if(!file){
        return false;
    }
    for(int i=0;i<m;i++){
        for(int j=0;j<m;j++){
            fprintf(file,"%d ",x[i][j]);
        }
        fprintf(file,"\n");
    }
    fclose(file);
    return true;
}

// Releases an m x m matrix whose rows were each allocated with new[].
static void delete_matrix(int **x,int m){
    for(int i=0;i<m;i++){
        delete[] x[i];
    }
    delete[] x;
}

// Benchmark driver:
// 1. Reads the matrix size m from stdin.
// 2. Fills two m x m matrices A and B with rand()%1000 values.
// 3. Saves them to matriaA.txt / matriaB.txt.
// 4. Times strassen_mul and mul on them, printing the elapsed clock() ticks
//    of each.
// A mismatch between the two products is reported on stderr.
int main(){
    int m;
    // Reject unreadable or negative sizes: new int*[m] cannot take them.
    if(!(cin >> m) || m < 0){
        std::cerr << "invalid matrix size" << std::endl;
        return 1;
    }
    int **A = new int*[m];
    int **B = new int*[m];
    for(int i=0;i<m;i++){A[i] = new int [m];}
    for(int i=0;i<m;i++){B[i] = new int [m];}
    // A is filled completely before B, keeping the original rand() sequence.
    for(int i=0;i<m;i++) for(int j=0;j<m;j++) A[i][j] = rand()%1000;
    for(int i=0;i<m;i++) for(int j=0;j<m;j++) B[i][j] = rand()%1000;

    // A failed fopen would make fprintf dereference NULL; warn and carry on
    // with the benchmark instead.
    if(!write_matrix_file("matriaA.txt",A,m)){
        std::cerr << "cannot write matriaA.txt" << std::endl;
    }
    if(!write_matrix_file("matriaB.txt",B,m)){
        std::cerr << "cannot write matriaB.txt" << std::endl;
    }

    time_start = clock();
    int **C = strassen_mul(A,B,m);
    time_end = clock();
    cout<<"strassen time:"<<time_end-time_start<<endl;

    time_start = clock();
    int **D = mul(A,B,m);
    time_end = clock();
    cout<<"normal time:"<<time_end-time_start<<endl;

    // Both algorithms must produce the same product. Only a mismatch is
    // reported, on stderr, so stdout stays unchanged.
    bool same = true;
    for(int i=0;i<m && same;i++){
        for(int j=0;j<m;j++){
            if(C[i][j]!=D[i][j]){
                same = false;
                break;
            }
        }
    }
    if(!same){
        std::cerr << "strassen and normal results differ" << std::endl;
    }

    delete_matrix(A,m);
    delete_matrix(B,m);
    delete_matrix(C,m);
    delete_matrix(D,m);
    return 0;
}
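// 256*256的矩阵相乘结果 (Result of multiplying two 256x256 matrices)
// 压力测试 (Stress test)
// 递归之后的压力测试 (Stress test after recursion was added)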