// Strassen matrix multiplication (Strassen矩阵乘法)

// Created on iPad spades.

#include <cstddef>
#include <iostream>
#include <vector>
using namespace std;


// Element-wise sum of two MatrixSize x MatrixSize matrices:
//   MatrixResult[r][c] = MatrixA[r][c] + MatrixB[r][c]
// Each element is read before it is written, so MatrixResult may be the same
// matrix as either operand.
void ADD(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for (int r = 0; r < MatrixSize; ++r) {
        const int *lhs = MatrixA[r];
        const int *rhs = MatrixB[r];
        int *out = MatrixResult[r];
        for (int c = 0; c < MatrixSize; ++c) {
            out[c] = lhs[c] + rhs[c];
        }
    }
}
// Naive (schoolbook) product of two MatrixSize x MatrixSize matrices:
//   MatrixResult[i][j] = sum over k of MatrixA[i][k] * MatrixB[k][j]
// Used by Strassen as its base case.
//
// The loops run in i-k-j order. The innermost loop walks a row of MatrixB and
// a row of MatrixResult contiguously, instead of striding down a column of
// MatrixB through the row-pointer array. Every element still accumulates its
// k terms in increasing k, so results match the i-j-k formulation.
// MatrixResult must not share storage with MatrixA or MatrixB.
void MUL(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for (int i = 0; i < MatrixSize; i++) {
        int *resultRow = MatrixResult[i];
        const int *aRow = MatrixA[i];

        for (int j = 0; j < MatrixSize; j++) {
            resultRow[j] = 0;
        }
        for (int k = 0; k < MatrixSize; k++) {
            const int aik = aRow[k];
            const int *bRow = MatrixB[k];
            for (int j = 0; j < MatrixSize; j++) {
                resultRow[j] += aik * bRow[j];
            }
        }
    }
}
// Element-wise difference of two MatrixSize x MatrixSize matrices:
//   MatrixResult[r][c] = MatrixA[r][c] - MatrixB[r][c]
// Each element is read before it is written, so MatrixResult may be the same
// matrix as either operand.
void SUB(int **MatrixA, int **MatrixB, int **MatrixResult, int MatrixSize){
    for (int r = 0; r < MatrixSize; ++r) {
        const int *minuend = MatrixA[r];
        const int *subtrahend = MatrixB[r];
        int *out = MatrixResult[r];
        for (int c = 0; c < MatrixSize; ++c) {
            out[c] = minuend[c] - subtrahend[c];
        }
    }
}
namespace {

// Owns a zero-initialised n x n int matrix in one contiguous buffer. It exposes
// the matrix through the int** row-pointer interface used by ADD/SUB/MUL/Strassen.
// The memory belongs to std::vector, so it is released on every exit path.
class StrassenScratch {
public:
    explicit StrassenScratch(int n)
        : cells_(static_cast<std::size_t>(n) * static_cast<std::size_t>(n), 0),
          rows_(static_cast<std::size_t>(n), nullptr) {
        const std::size_t width = rows_.size();
        for (std::size_t r = 0; r < width; ++r) {
            rows_[r] = cells_.data() + r * width;
        }
    }

    // rows_ points into cells_, so a copy would alias the source's buffer.
    StrassenScratch(const StrassenScratch &) = delete;
    StrassenScratch &operator=(const StrassenScratch &) = delete;

    int **get() { return rows_.data(); }

private:
    std::vector<int> cells_;   // n*n elements, row-major
    std::vector<int *> rows_;  // rows_[r] points at row r inside cells_
};

// Returns row pointers addressing the size x size sub-block of `m` whose
// top-left element is m[row0][col0]. Nothing is copied: reads and writes
// through the returned pointers go directly to `m`.
std::vector<int *> QuadrantView(int **m, int row0, int col0, int size) {
    std::vector<int *> rows(static_cast<std::size_t>(size));
    for (int r = 0; r < size; ++r) {
        rows[static_cast<std::size_t>(r)] = m[row0 + r] + col0;
    }
    return rows;
}

}  // namespace

// Computes MatrixC = MatrixA * MatrixB for N x N int matrices using Strassen's
// algorithm (7 recursive half-size products instead of 8).
//
// N         : matrix dimension; N <= 64 is handled by the naive MUL.
// MatrixA/B : inputs, indexed as M[row][col] with 0 <= row, col < N.
// MatrixC   : output; only its top-left N x N elements are written.
//
// Odd sizes above the cutoff are zero-padded to N + 1 so the quadrant split is
// exact. Without the padding, N / 2 would silently drop the last row and column.
// Arithmetic is plain int; results that overflow int are not meaningful.
void Strassen(int N, int **MatrixA,int **MatrixB, int **MatrixC){
    // Small problems: recursion overhead outweighs the saved multiplication.
    if (N <= 64) {
        MUL(MatrixA, MatrixB, MatrixC, N);
        return;
    }

    if (N % 2 != 0) {
        // The quadrant split needs an even size. Embed A and B in
        // (N+1) x (N+1) matrices whose extra row/column is zero. This leaves
        // A*B unchanged in the top-left N x N block, so compute there and
        // copy the result back.
        const int padded = N + 1;
        StrassenScratch paddedA(padded);
        StrassenScratch paddedB(padded);
        StrassenScratch paddedC(padded);
        int **pa = paddedA.get();
        int **pb = paddedB.get();
        int **pc = paddedC.get();
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                pa[i][j] = MatrixA[i][j];
                pb[i][j] = MatrixB[i][j];
            }
        }
        Strassen(padded, pa, pb, pc);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                MatrixC[i][j] = pc[i][j];
            }
        }
        return;
    }

    const int HalfSize = N / 2;

    // Quadrants of the operands and of the result, addressed in place.
    std::vector<int *> A11 = QuadrantView(MatrixA, 0, 0, HalfSize);
    std::vector<int *> A12 = QuadrantView(MatrixA, 0, HalfSize, HalfSize);
    std::vector<int *> A21 = QuadrantView(MatrixA, HalfSize, 0, HalfSize);
    std::vector<int *> A22 = QuadrantView(MatrixA, HalfSize, HalfSize, HalfSize);

    std::vector<int *> B11 = QuadrantView(MatrixB, 0, 0, HalfSize);
    std::vector<int *> B12 = QuadrantView(MatrixB, 0, HalfSize, HalfSize);
    std::vector<int *> B21 = QuadrantView(MatrixB, HalfSize, 0, HalfSize);
    std::vector<int *> B22 = QuadrantView(MatrixB, HalfSize, HalfSize, HalfSize);

    std::vector<int *> C11 = QuadrantView(MatrixC, 0, 0, HalfSize);
    std::vector<int *> C12 = QuadrantView(MatrixC, 0, HalfSize, HalfSize);
    std::vector<int *> C21 = QuadrantView(MatrixC, HalfSize, 0, HalfSize);
    std::vector<int *> C22 = QuadrantView(MatrixC, HalfSize, HalfSize, HalfSize);

    // The seven half-size products, plus two temporaries for operand sums.
    StrassenScratch M1(HalfSize);
    StrassenScratch M2(HalfSize);
    StrassenScratch M3(HalfSize);
    StrassenScratch M4(HalfSize);
    StrassenScratch M5(HalfSize);
    StrassenScratch M6(HalfSize);
    StrassenScratch M7(HalfSize);
    StrassenScratch AResult(HalfSize);
    StrassenScratch BResult(HalfSize);
    int **aTmp = AResult.get();
    int **bTmp = BResult.get();

    // M1 = (A11 + A22)(B11 + B22)
    ADD(A11.data(), A22.data(), aTmp, HalfSize);
    ADD(B11.data(), B22.data(), bTmp, HalfSize);
    Strassen(HalfSize, aTmp, bTmp, M1.get());

    // M2 = (A21 + A22) B11
    ADD(A21.data(), A22.data(), aTmp, HalfSize);
    Strassen(HalfSize, aTmp, B11.data(), M2.get());

    // M3 = A11 (B12 - B22)
    SUB(B12.data(), B22.data(), bTmp, HalfSize);
    Strassen(HalfSize, A11.data(), bTmp, M3.get());

    // M4 = A22 (B21 - B11)
    SUB(B21.data(), B11.data(), bTmp, HalfSize);
    Strassen(HalfSize, A22.data(), bTmp, M4.get());

    // M5 = (A11 + A12) B22
    ADD(A11.data(), A12.data(), aTmp, HalfSize);
    Strassen(HalfSize, aTmp, B22.data(), M5.get());

    // M6 = (A21 - A11)(B11 + B12)
    SUB(A21.data(), A11.data(), aTmp, HalfSize);
    ADD(B11.data(), B12.data(), bTmp, HalfSize);
    Strassen(HalfSize, aTmp, bTmp, M6.get());

    // M7 = (A12 - A22)(B21 + B22)
    SUB(A12.data(), A22.data(), aTmp, HalfSize);
    ADD(B21.data(), B22.data(), bTmp, HalfSize);
    Strassen(HalfSize, aTmp, bTmp, M7.get());

    // Assemble the result quadrants directly into MatrixC.
    // C11 = M1 + M4 - M5 + M7
    ADD(M1.get(), M4.get(), aTmp, HalfSize);
    SUB(aTmp, M5.get(), bTmp, HalfSize);
    ADD(bTmp, M7.get(), C11.data(), HalfSize);

    // C12 = M3 + M5
    ADD(M3.get(), M5.get(), C12.data(), HalfSize);

    // C21 = M2 + M4
    ADD(M2.get(), M4.get(), C21.data(), HalfSize);

    // C22 = M1 - M2 + M3 + M6
    ADD(M1.get(), M3.get(), aTmp, HalfSize);
    SUB(aTmp, M2.get(), bTmp, HalfSize);
    ADD(bTmp, M6.get(), C22.data(), HalfSize);
}
// Prints a MatrixSize x MatrixSize matrix to standard output. It writes a
// leading blank line, then one line per row with every value followed by a tab.
void PrintMatrix(int **MatrixA, int MatrixSize){
    std::cout << std::endl;
    for (int r = 0; r < MatrixSize; ++r) {
        const int *rowValues = MatrixA[r];
        for (int c = 0; c < MatrixSize; ++c) {
            std::cout << rowValues[c] << '\t';
        }
        // End of row: the original emitted this when (column + 1) reached MatrixSize.
        std::cout << std::endl;
    }
}
// Sets every element of the top-left MatrixSize x MatrixSize block of MatrixA to n.
void FillMatrix(int **MatrixA, int MatrixSize, int n){
    for (int r = 0; r < MatrixSize; ++r) {
        int *rowValues = MatrixA[r];
        for (int c = 0; c < MatrixSize; ++c) {
            rowValues[c] = n;
        }
    }
}
// Demo: multiplies a 100 x 100 matrix of ones by a 100 x 100 matrix of twos
// with Strassen and prints the product.
int main() {
    const int MatrixSize = 100;

    // Each matrix is stored contiguously in a std::vector, which frees it
    // automatically. The matching row-pointer vector adapts that storage to the
    // int** interface of FillMatrix/Strassen/PrintMatrix. The original raw
    // new[] buffers were never deleted.
    const std::size_t cellCount = static_cast<std::size_t>(MatrixSize) * static_cast<std::size_t>(MatrixSize);
    std::vector<int> storageA(cellCount);
    std::vector<int> storageB(cellCount);
    std::vector<int> storageC(cellCount);

    std::vector<int *> MatrixA(static_cast<std::size_t>(MatrixSize));
    std::vector<int *> MatrixB(static_cast<std::size_t>(MatrixSize));
    std::vector<int *> MatrixC(static_cast<std::size_t>(MatrixSize));
    for (int i = 0; i < MatrixSize; i++) {
        const std::size_t row = static_cast<std::size_t>(i);
        const std::size_t offset = row * static_cast<std::size_t>(MatrixSize);
        MatrixA[row] = storageA.data() + offset;
        MatrixB[row] = storageB.data() + offset;
        MatrixC[row] = storageC.data() + offset;
    }

    FillMatrix(MatrixA.data(), MatrixSize, 1);
    FillMatrix(MatrixB.data(), MatrixSize, 2);

    Strassen(MatrixSize, MatrixA.data(), MatrixB.data(), MatrixC.data());

    PrintMatrix(MatrixC.data(), MatrixSize);
    return 0;
}
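/*
 * Non-program text captured from the web page this code was copied from
 * (comment box and red-envelope tipping widget). It is not part of the
 * program and is kept only as a comment.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
*/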