Strassen矩阵乘法(C++)

思路

两个矩阵A,B相乘时.有以下三种方法

暴力计算法. 三个for循环, 这时候时间复杂度为O(n^3).因为Cij=∑(k=1->n)Aik*Bkj,需要一个循环, 且C中有n^2个元素, 所以时间复杂度为O(n^3)

分治法. 首先将A,B,C分成相等大小的方块矩阵.

所以C11=A11*B11+A12*B21, C12=A11*B12+A12*B22,

C21=A21*B11+A22*B21, C22=A21*B12+A22*B22

用T(n)表示n*n矩阵的乘法, 所以有T(n)=8T(n/2)+Θ(n^2). 其中, 8T(n/2)表示8次子矩阵乘法, 子矩阵的规模为n/2 * n/2. θ(n^2)表示4次矩阵加法的时间复杂度以及合并C矩阵的时间复杂度.最后结果是Θ(n^3)与暴力计算时间复杂度相同.

Strassen算法,可以将时间复杂度优化到O(n^log7).

现在重新定义7个新矩阵

M1=(A11+A22)*(B11+B22)

M2=(A21+A22)*B11

M3=A11*(B12-B22)

M4=A22*(B21-B11)

M5=(A11+A12)*B22

M6=(A21-A11)*(B11+B12)

M7=(A12-A22)*(B21+B22)

结果矩阵C可以组合上述矩阵,如下

C11=M1+M4-M5+M7

C12=M3+M5

C21=M2+M4

C22=M1-M2+M3+M6

这时候共用了7次乘法,18次加减法运算. 写出递推公式T(n)=7T(n/2)+Θ(n^2). 最终结果是O(n^log7)=O(n^2.807).

代码如下:

#include <bits/stdc++.h>

using namespace std;

// 矩阵相乘的暴力求解
void MUL(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){
    for(int i=0;i<Msize;i++){
        for(int j=0;j<Msize;j++){
            MatrixResult[i][j]=0;
            for(int k=0;k<Msize;k++){
                MatrixResult[i][j]+=MatrixA[i][k]*MatrixB[k][j];
            }
        }
    }
}

// 矩阵相加运算
void ADD(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){
    for(int i=0;i<Msize;i++){
        for(int j=0;j<Msize;j++){
            MatrixResult[i][j]=MatrixA[i][j]+MatrixB[i][j];
        }
    }
}

// 矩阵相减运算
void SUB(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){
    for(int i=0;i<Msize;i++){
        for(int j=0;j<Msize;j++){
            MatrixResult[i][j]=MatrixA[i][j]-MatrixB[i][j];
        }
    }
}

// Strassen算法
void Strassen(int N,int** MatrixA,int** MatrixB,int** MatrixC){
    int halfSize=N/2;
    if(N<=2){
        MUL(MatrixA,MatrixB,MatrixC,N);
    }
    else{
        // 创建二维数组指针
        int** A11;
        int** A12;
        int** A21;
        int** A22;

        int** B11;
        int** B12;
        int** B21;
        int** B22;

        int** C11;
        int** C12;
        int** C21;
        int** C22;

        int** M1;
        int** M2;
        int** M3;
        int** M4;
        int** M5;
        int** M6;
        int** M7;
        int** AResult;
        int** BResult;
        // 初始化
        A11=new int*[halfSize];
        A12=new int*[halfSize];
        A21=new int*[halfSize];
        A22=new int*[halfSize];

        B11=new int*[halfSize];
        B12=new int*[halfSize];
        B21=new int*[halfSize];
        B22=new int*[halfSize];

        C11=new int*[halfSize];
        C12=new int*[halfSize];
        C21=new int*[halfSize];
        C22=new int*[halfSize];

        M1=new int*[halfSize];
        M2=new int*[halfSize];
        M3=new int*[halfSize];
        M4=new int*[halfSize];
        M5=new int*[halfSize];
        M6=new int*[halfSize];
        M7=new int*[halfSize];
        AResult=new int*[halfSize];
        BResult=new int*[halfSize];

        for(int i=0;i<halfSize;i++){
            A11[i]=new int[halfSize];
            A12[i]=new int[halfSize];
            A21[i]=new int[halfSize];
            A22[i]=new int[halfSize];

            B11[i]=new int[halfSize];
            B12[i]=new int[halfSize];
            B21[i]=new int[halfSize];
            B22[i]=new int[halfSize];

            C11[i]=new int[halfSize];
            C12[i]=new int[halfSize];
            C21[i]=new int[halfSize];
            C22[i]=new int[halfSize];

            M1[i]=new int[halfSize];
            M2[i]=new int[halfSize];
            M3[i]=new int[halfSize];
            M4[i]=new int[halfSize];
            M5[i]=new int[halfSize];
            M6[i]=new int[halfSize];
            M7[i]=new int[halfSize];

            AResult[i]=new int[halfSize];
            BResult[i]=new int[halfSize];
        }

        // 把MatrixA和MatrixB分块
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                A11[i][j]=MatrixA[i][j];
                A12[i][j]=MatrixA[i][j+N/2];
                A21[i][j]=MatrixA[i+N/2][j];
                A22[i][j]=MatrixA[i+N/2][j+N/2];

                B11[i][j]=MatrixB[i][j];
                B12[i][j]=MatrixB[i][j+N/2];
                B21[i][j]=MatrixB[i+N/2][j];
                B22[i][j]=MatrixB[i+N/2][j+N/2];
            }
        }

        // M1=(A11+A22)*(B11+B22)
        ADD(A11,A22,AResult,halfSize);
        ADD(B11,B22,BResult,halfSize);
        Strassen(halfSize,AResult,BResult,M1);

        // M2=(A21+A22)*B11
        ADD(A21,A22,AResult,halfSize);
        Strassen(halfSize,AResult,B11,M2);

        // M3=A11*(B12-B22)
        SUB(B12,B22,BResult,halfSize);
        Strassen(halfSize,A11,BResult,M3);

        // M4=A22*(B21-B11)
        SUB(B21,B11,BResult,halfSize);
        Strassen(halfSize,A22,BResult,M4);

        // M5=(A11+A12)B22
        ADD( A11, A12, AResult, halfSize);
        Strassen(halfSize, AResult, B22, M5);

        // M6=(A21-A11)*(B11+B12)
        SUB( A21, A11, AResult, halfSize);
        ADD( B11, B12, BResult, halfSize);
        Strassen( halfSize, AResult, BResult, M6);

        // M7=(A12-A22)*(B21+B22)
        SUB(A12, A22, AResult, halfSize);
        ADD(B21, B22, BResult, halfSize);
        Strassen(halfSize, AResult, BResult, M7);

        // C11=M1+M4-M5+M7
        ADD( M1, M4, AResult, halfSize);
        SUB( M7, M5, BResult, halfSize);
        ADD( AResult, BResult, C11, halfSize);

        // C12=M3+M5
        ADD( M3, M5, C12, halfSize);

        // C21=M2+M4
        ADD( M2, M4, C21, halfSize);

        // C22=M1-M2+M3+M6
        ADD( M1, M3, AResult, halfSize);
        SUB( M6, M2, BResult, halfSize);
        ADD( AResult, BResult, C22, halfSize);

        // 把C11,C12,C21,C22矩阵合并成一个大矩阵MatrixC
        for(int i=0;i<N/2;i++){
            for(int j=0;j<N/2;j++){
                MatrixC[i][j]=C11[i][j];
                MatrixC[i][j+N/2]=C12[i][j];
                MatrixC[i+N/2][j]=C21[i][j];
                MatrixC[i+N/2][j+N/2]=C22[i][j];
            }
        }

        // 释放空间
        for (int i = 0; i < halfSize; i++)
        {
            delete[] A11[i];delete[] A12[i];delete[] A21[i];
            delete[] A22[i];

            delete[] B11[i];delete[] B12[i];delete[] B21[i];
            delete[] B22[i];
            delete[] C11[i];delete[] C12[i];delete[] C21[i];
            delete[] C22[i];
            delete[] M1[i];delete[] M2[i];delete[] M3[i];delete[] M4[i];
            delete[] M5[i];delete[] M6[i];delete[] M7[i];
            delete[] AResult[i];delete[] BResult[i] ;
        }
        delete[] A11;delete[] A12;delete[] A21;delete[] A22;
        delete[] B11;delete[] B12;delete[] B21;delete[] B22;
        delete[] C11;delete[] C12;delete[] C21;delete[] C22;
        delete[] M1;delete[] M2;delete[] M3;delete[] M4;delete[] M5;
        delete[] M6;delete[] M7;
        delete[] AResult;
        delete[] BResult;
    }

}

int main()
{
    int MSize;
    cin >> MSize;

    // 定义三个矩阵
    int** MatrixA;
    int** MatrixB;
    int** MatrixC;

    // 初始化三个矩阵
    MatrixA=new int*[MSize];
    MatrixB=new int*[MSize];
    MatrixC=new int*[MSize];
    for(int i=0;i<MSize;i++){
        MatrixA[i]=new int[MSize];
        MatrixB[i]=new int[MSize];
        MatrixC[i]=new int[MSize];
    }

    // 输入相乘的矩阵
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cin >> MatrixA[i][j];
        }
    }
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cin >> MatrixB[i][j];
        }
    }

    Strassen(MSize,MatrixA,MatrixB,MatrixC);

    // 打印输出结果矩阵
    for(int i=0;i<MSize;i++){
        for(int j=0;j<MSize;j++){
            cout << MatrixC[i][j] << " ";
        }
        cout << endl;
    }

    return 0;
}


/* 一组数据
4
1 2 4 7
8 3 6 5
4 7 2 1
6 4 3 1
1 2 4 7
8 3 6 5
4 7 2 1
6 4 3 1
*/

  • 11
    点赞
  • 61
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
以下是使用C语言实现的Strassen标准矩阵乘法算法的代码: ``` #include <stdio.h> #define N 4 // 矩阵大小 #define THRESHOLD 2 // 阈值 // 矩阵加法 void add_matrix(int a[][N], int b[][N], int c[][N]) { int i, j; for (i = 0; i < N; i++) { for (j = 0; j < N; j++) { c[i][j] = a[i][j] + b[i][j]; } } } // 矩阵减法 void sub_matrix(int a[][N], int b[][N], int c[][N]) { int i, j; for (i = 0; i < N; i++) { for (j = 0; j < N; j++) { c[i][j] = a[i][j] - b[i][j]; } } } // Strassen矩阵乘法 void strassen_mul(int a[][N], int b[][N], int c[][N]) { // 达到阈值,使用标准矩阵乘法 if (N <= THRESHOLD) { int i, j, k; for (i = 0; i < N; i++) { for (j = 0; j < N; j++) { c[i][j] = 0; for (k = 0; k < N; k++) { c[i][j] += a[i][k] * b[k][j]; } } } return; } // 处理矩阵的大小并向上取整 int size = N / 2; if (N % 2 != 0) { size += 1; } int A[size][size], B[size][size], C[size][size], D[size][size]; int E[size][size], F[size][size], G[size][size], H[size][size]; int P1[size][size], P2[size][size], P3[size][size], P4[size][size], P5[size][size], P6[size][size], P7[size][size]; int tmp1[size][size], tmp2[size][size]; // 拆分矩阵 int i, j; for (i = 0; i < size; i++) { for (j = 0; j < size; j++) { A[i][j] = a[i][j]; B[i][j] = a[i][j + size]; C[i][j] = a[i + size][j]; D[i][j] = a[i + size][j + size]; E[i][j] = b[i][j]; F[i][j] = b[i][j + size]; G[i][j] = b[i + size][j]; H[i][j] = b[i + size][j + size]; } } // 计算P1到P7 sub_matrix(F, H, tmp1); strassen_mul(A, tmp1, P1); add_matrix(A, B, tmp1); strassen_mul(tmp1, H, P2); add_matrix(C, D, tmp1); strassen_mul(tmp1, E, P3); sub_matrix(G, E, tmp1); strassen_mul(D, tmp1, P4); add_matrix(A, D, tmp1); add_matrix(E, H, tmp2); strassen_mul(tmp1, tmp2, P5); sub_matrix(B, D, tmp1); add_matrix(G, H, tmp2); strassen_mul(tmp1, tmp2, P6); sub_matrix(A, C, tmp1); add_matrix(E, F, tmp2); strassen_mul(tmp1, tmp2, P7); // 计算结果矩阵 add_matrix(P5, P4, tmp1); sub_matrix(tmp1, P2, tmp2); add_matrix(tmp2, P6, c[0]); add_matrix(P1, P2, c[1]); add_matrix(P3, P4, c[2]); add_matrix(P5, P1, tmp1); sub_matrix(tmp1, P3, tmp2); sub_matrix(tmp2, P7, c[3]); } int main() { int a[N][N] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}; int b[N][N] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}; int c[N][N]; int i, j; strassen_mul(a, b, c); printf("Result:\n"); for (i = 0; i < N; i++) { for (j = 0; j < N; j++) { printf("%d ", c[i][j]); } printf("\n"); } return 0; } ``` 本代码中定义了THRESHOLD变量,当矩阵大小小于等于阈值时,使用标准矩阵乘法算法计算。简单起见,本代码中矩阵大小固定为4 * 4,可以根据需要修改。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值