strassen矩阵乘法

算法课实验

#include "iostream"
#include "ctime"
#include "cstdlib"
using namespace std;
clock_t time_start,time_end;
int **mul(int**a,int**b,int m){
    int **c = new int*[m];
    for(int i=0;i<m;i++){
        c[i] = new int[m];
        for(int j=0;j<m;j++){
            c[i][j] = 0;
        }
    }
    for(int i=0;i<m;i++){
        for(int j=0;j<m;j++){
            for(int k=0;k<m;k++){
                c[i][j] += a[i][k]*b[k][j];
            }
        }
    }
    return c;
}
int **strassen_add(int **a,int **b,int m){
    int **c = new int*[m];
    for(int i=0;i<m;i++){
        c[i] = new int[m];
        for(int j=0;j<m;j++){
            c[i][j] = 0;
        }
    }
    for(int i=0;i<m;i++)for(int j=0;j<m;j++){
        c[i][j] = a[i][j]+b[i][j];
    }
    return c;
}
int **strassen_mul(int **a,int **b,int m){
    int **c = new int*[m];
    for(int i=0;i<m;i++){
        c[i] = new int[m];
    }
    int **a1,**a2,**a3,**a4;
    int **b1,**b2,**b3,**b4;
    a1 = new int*[m/2];
    a2 = new int*[m/2];
    a3 = new int*[m/2];
    a4 = new int*[m/2];
    b1 = new int*[m/2];
    b2 = new int*[m/2];
    b3 = new int*[m/2];
    b4 = new int*[m/2];
    for(int i=0;i<m/2;i++){
        a1[i] = new int[m/2];
        a2[i] = new int[m/2];
        a3[i] = new int[m/2];
        a4[i] = new int[m/2];
        b1[i] = new int[m/2];
        b2[i] = new int[m/2];
        b3[i] = new int[m/2];
        b4[i] = new int[m/2];
    }
    for(int i=0;i<m/2;i++){
        for(int j=0;j<m/2;j++){
            a1[i][j] = a[i][j];
            b1[i][j] = b[i][j];
            a2[i][j] = a[i][j+m/2];
            b2[i][j] = b[i][j+m/2];
            a3[i][j] = a[i+m/2][j];
            b3[i][j] = b[i+m/2][j];
            a4[i][j] = a[i+m/2][j+m/2];
            b4[i][j] = b[i+m/2][j+m/2];
        }
    }
    int **m1,**m2,**m3,**m4,**m5,**m6,**m7;
    if(m>=64 and m%2==0) {
        m1 = strassen_mul(a1, strassen_cut(b2, b4, m / 2), m / 2);
        m2 = strassen_mul(strassen_add(a1, a2, m / 2), b4, m / 2);
        m3 = strassen_mul(strassen_add(a3, a4, m / 2), b1, m / 2);
        m4 = strassen_mul(a4, strassen_cut(b3, b1, m / 2), m / 2);
        m5 = strassen_mul(strassen_add(a1, a4, m / 2), strassen_add(b1, b4, m / 2), m / 2);
        m6 = strassen_mul(strassen_cut(a2, a4, m / 2), strassen_add(b3, b4, m / 2), m / 2);
        m7 = strassen_mul(strassen_cut(a1, a3, m / 2), strassen_add(b1, b2, m / 2), m / 2);
    } else{
        m1 = mul(a1, strassen_cut(b2, b4, m / 2), m / 2);
        m2 = mul(strassen_add(a1, a2, m / 2), b4, m / 2);
        m3 = mul(strassen_add(a3, a4, m / 2), b1, m / 2);
        m4 = mul(a4, strassen_cut(b3, b1, m / 2), m / 2);
        m5 = mul(strassen_add(a1, a4, m / 2), strassen_add(b1, b4, m / 2), m / 2);
        m6 = mul(strassen_cut(a2, a4, m / 2), strassen_add(b3, b4, m / 2), m / 2);
        m7 = mul(strassen_cut(a1, a3, m / 2), strassen_add(b1, b2, m / 2), m / 2);
    }
    int **c1 = strassen_add(strassen_cut(strassen_add(m5,m4,m/2),m2,m/2),m6,m/2);
    int **c2 = strassen_add(m1,m2,m/2);
    int **c3 = strassen_add(m3,m4,m/2);
    int **c4 = strassen_cut(strassen_cut(strassen_add(m5,m1,m/2),m3,m/2),m7,m/2);
    for (int i=0;i<m/2;i++){
        for(int j=0;j<m/2;j++){
            c[i][j] = c1[i][j];
            c[i][j+m/2] = c2[i][j];
            c[i+m/2][j] = c3[i][j];
            c[i+m/2][j+m/2] = c4[i][j];
        }
    }
    return c;
}
int main(){
    int m;
    cin >> m ;
    int **A = new int*[m];
    int **B = new int*[m];
    for(int i=0;i<m;i++){A[i] = new int [m];}
    for(int i=0;i<m;i++){B[i] = new int [m];}
    for(int i=0;i<m;i++) for(int j=0;j<m;j++) A[i][j] = rand()%1000;
    for(int i=0;i<m;i++) for(int j=0;j<m;j++) B[i][j] = rand()%1000;
    FILE *fileA = fopen("matriaA.txt","w");
    FILE *fileB = fopen("matriaB.txt","w");
    for(int i=0;i<m;i++){
        for (int j=0;j<m;j++) {
            fprintf(fileA,"%d ",A[i][j]);
        }
        fprintf(fileA,"\n");
    }
    for(int i=0;i<m;i++){
        for (int j=0;j<m;j++) {
            fprintf(fileB,"%d ",B[i][j]);
        }
        fprintf(fileB,"\n");
    }
    fclose(fileA);
    fclose(fileB);
    time_start = clock();
    int **C= strassen_mul(A,B,m);
    //for(int i=0;i<m;i++){
    //    for (int j=0;j<m;j++) {
    //        cout<<C[i][j]<<" ";
    //    }
    //    cout<<endl;
    //}
    time_end = clock();
    cout<<"strassen time:"<<time_end-time_start<<endl;
    time_start = clock();
    C = mul(A,B,m);
    //for(int i=0;i<m;i++){
    //    for (int j=0;j<m;j++) {
    //        cout<<C[i][j]<<" ";
    //    }
    //    cout<<endl;
    //}
    time_end = clock();
    cout<<"normal time:"<<time_end-time_start<<endl;
    return 0;
}

256*256的矩阵相乘
256*256的矩阵相乘结果

在这里插入图片描述
压力测试
在这里插入图片描述
递归之后的压力测试

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值