Strassen快速矩阵乘法的C++实现

Strassen矩阵乘法是一个时间复杂度低于 $O(N^3)$ 的矩阵乘法算法：它通过一系列优化技巧，把朴素分治矩阵乘法中对子矩阵乘积的八次递归计算减少到七次，从而得到 $O(N^{\log_2 7}) \approx O(N^{2.81})$ 的时间复杂度。该算法的详细讲解参见《算法导论》（第三版）中译本 p43–p47，这里只给出C++实现，不再重复阐述。
C++代码:

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>
using namespace std;

/// Element-wise sum of a rectangular window of `left` and an equally sized
/// window of `right`, written to the top-left corner of `result`.
///
/// Coordinates are (x = column, y = row). The window of `left` spans columns
/// [left_up_x, left_down_x] and rows [left_up_y, left_down_y], both inclusive;
/// the window of `right` starts at column right_up_x, row right_up_y.
/// `result` must already be at least as large as the window. Each output cell
/// is computed from the inputs read in the same step, so `result` may be the
/// same object as `left`/`right` when those windows start at (0, 0).
void matrixAdd(std::vector<std::vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, std::vector<std::vector<int>>& right, size_t right_up_x, size_t right_up_y, std::vector<std::vector<int>>& result)
{
    const size_t lastRow = left_down_y - left_up_y;
    const size_t lastCol = left_down_x - left_up_x;
    for (size_t r = 0; r <= lastRow; ++r)
    {
        std::vector<int>& dst = result[r];
        const std::vector<int>& lhs = left[left_up_y + r];
        const std::vector<int>& rhs = right[right_up_y + r];
        for (size_t c = 0; c <= lastCol; ++c)
            dst[c] = lhs[left_up_x + c] + rhs[right_up_x + c];
    }
}

/// Element-wise difference (left window minus right window), written to the
/// top-left corner of `result`.
///
/// Coordinates are (x = column, y = row). The window of `left` spans columns
/// [left_up_x, left_down_x] and rows [left_up_y, left_down_y], both inclusive;
/// the window of `right` starts at column right_up_x, row right_up_y.
/// `result` must already be at least as large as the window, and may be the
/// same object as `left` when the left window starts at (0, 0), because each
/// cell is read before it is written.
void matrixSub(std::vector<std::vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, std::vector<std::vector<int>>& right, size_t right_up_x, size_t right_up_y, std::vector<std::vector<int>>& result)
{
    const size_t lastRow = left_down_y - left_up_y;
    const size_t lastCol = left_down_x - left_up_x;
    size_t r = 0;
    while (r <= lastRow)
    {
        const std::vector<int>& minuend = left[left_up_y + r];
        const std::vector<int>& subtrahend = right[right_up_y + r];
        std::vector<int>& dst = result[r];
        for (size_t c = 0; c <= lastCol; ++c)
        {
            const int a = minuend[left_up_x + c];
            const int b = subtrahend[right_up_x + c];
            dst[c] = a - b;
        }
        ++r;
    }
}

/// Copies the window of `left` spanning columns [left_up_x, left_down_x] and
/// rows [left_up_y, left_down_y] (both inclusive) into the top-left corner of
/// `result`, which must already be at least that large.
void extract(std::vector<std::vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, std::vector<std::vector<int>>& result)
{
    const size_t lastRow = left_down_y - left_up_y;
    const size_t lastCol = left_down_x - left_up_x;
    for (size_t r = 0; r <= lastRow; ++r)
    {
        const std::vector<int>& src = left[left_up_y + r];
        std::vector<int>& dst = result[r];
        for (size_t c = 0; c <= lastCol; ++c)
            dst[c] = src[left_up_x + c];
    }
}

/// Copies `left` into the window of `be_added` whose top-left corner is
/// (column left_up_x, row left_up_y) and whose bottom-right corner is
/// (column left_down_x, row left_down_y), both inclusive.
///
/// Despite its name this OVERWRITES the destination cells; nothing is summed.
/// Only the part of `left` that fits inside the window is copied, and an empty
/// window (down < up) copies nothing.
void add(std::vector<std::vector<int>>& be_added, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, std::vector<std::vector<int>>& left)
{
    for (size_t row = left_up_y; row <= left_down_y; ++row)
    {
        for (size_t col = left_up_x; col <= left_down_x; ++col)
        {
            be_added[row][col] = left[row - left_up_y][col - left_up_x];
        }
    }
}

void Strassen(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &P)
{
    if (A.size() == 1 && B.size() == 1 && B[0].size() == 1)
    {
        P[0][0] = A[0][0] * B[0][0];
        return;
    }
    
    size_t s = A.size();
    size_t m = B[0].size();
    if (A.size() % 2 == 1)
    {
        A.push_back(vector<int>(B.size(), 0));
    }

    if (B.size() % 2 == 1)
    {
        for (size_t i = 0; i < A.size(); ++i)
        {
            A[i].push_back(0);
        }

        B.push_back(vector<int>(B[0].size(), 0));
    }

    if (B[0].size() % 2 == 1)
    {
        for (size_t j = 0; j < B.size(); ++j)
            B[j].push_back(0);
    }

    vector<vector<int>> S1(B.size()/2, vector<int>(B[0].size()/2, 0));
    matrixSub(B, B[0].size()/2, 0, B[0].size() - 1, B.size()/2 - 1, B, B[0].size()/2, B.size()/2, S1);
    vector<vector<int>> S2(A.size()/2, vector<int>(B.size()/2, 0));
    matrixAdd(A, 0, 0, B.size()/2 - 1, A.size()/2 - 1, A, B.size()/2, 0, S2);
    vector<vector<int>> S3(A.size() / 2, vector<int>(B.size() / 2, 0));
    matrixAdd(A, 0, A.size() / 2, B.size() / 2 - 1, A.size() - 1, A, B.size() / 2, A.size() / 2, S3);
    vector<vector<int>> S4(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixSub(B, 0, B.size()/2, B[0].size()/2 - 1, B.size() - 1, B, 0, 0, S4);
    vector<vector<int>> S5(A.size() / 2, vector<int>(B.size() / 2, 0));
    matrixAdd(A, 0, 0, B.size()/2 - 1, A.size()/2 - 1, A, B.size()/2, A.size()/2, S5);
    vector<vector<int>> S6(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(B, 0, 0, B[0].size()/2 - 1, B.size()/2 - 1, B, B[0].size()/2, B.size()/2, S6);
    vector<vector<int>> S7(A.size() / 2, vector<int>(B.size() / 2, 0));
    matrixSub(A, B.size()/2, 0, B.size()-1, A.size()/2 -1, A, B.size()/2, A.size()/2, S7);
    vector<vector<int>> S8(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(B, 0, B.size() / 2, B[0].size() / 2 - 1, B.size() - 1, B, B[0].size() / 2, B.size() / 2, S8);
    vector<vector<int>> S9(A.size() / 2, vector<int>(B.size() / 2, 0));
    matrixSub(A, 0, 0, B.size() / 2 - 1, A.size() / 2 - 1, A, 0, A.size() / 2, S9);
    vector<vector<int>> S10(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(B, 0, 0, B[0].size() / 2 - 1, B.size() / 2 - 1, B, B[0].size() / 2, 0, S10);

    vector<vector<int>> A11(A.size()/2, vector<int>(B.size()/2, 0));
    extract(A, 0, 0, B.size() / 2 - 1, A.size() / 2 - 1, A11);
    vector<vector<int>> P1(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(A11, S1, P1);

    vector<vector<int>> B22(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    extract(B, B[0].size()/2, B.size()/2, B[0].size() - 1, B.size() - 1, B22);
    vector<vector<int>> P2(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(S2, B22, P2);

    vector<vector<int>> B11(B.size() / 2, vector<int>(B[0].size() / 2, 0));
    extract(B, 0, 0, B[0].size() / 2 - 1, B.size() / 2 - 1, B11);
    vector<vector<int>> P3(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(S3, B11, P3);

    vector<vector<int>> A22(A.size() / 2, vector<int>(B.size() / 2, 0));
    extract(A, B.size() / 2, A.size() / 2, B.size() - 1, A.size() - 1, A22);
    vector<vector<int>> P4(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(A22, S4, P4);

    vector<vector<int>> P5(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(S5, S6, P5);

    vector<vector<int>> P6(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(S7, S8, P6);

    vector<vector<int>> P7(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    Strassen(S9, S10, P7);

    vector<vector<int>> C11(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(P5, 0, 0, P5[0].size() - 1, P5.size() - 1, P4, 0, 0, C11);
    matrixSub(C11, 0, 0, C11[0].size() - 1, C11.size() - 1, P2, 0, 0, C11);
    matrixAdd(C11, 0, 0, C11[0].size() - 1, C11.size() - 1, P6, 0, 0, C11);
    add(P, 0, 0, C11[0].size() - 1, C11.size() - 1, C11);

    vector<vector<int>> C12(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(P1, 0, 0, P1[0].size() - 1, P1.size() - 1, P2, 0, 0, C12);
    add(P, C11[0].size(), 0, P[0].size() - 1, C11.size() - 1, C12);

    vector<vector<int>> C21(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(P3, 0, 0, P3[0].size() - 1, P3.size() - 1, P4, 0, 0, C21);
    add(P, 0, C11.size(), C11[0].size() - 1, P.size() - 1, C21);

    vector<vector<int>> C22(A.size() / 2, vector<int>(B[0].size() / 2, 0));
    matrixAdd(P5, 0, 0, P5[0].size() - 1, P5.size() - 1, P1, 0, 0, C22);
    matrixSub(C22, 0, 0, C22[0].size() - 1, C22.size() - 1, P3, 0, 0, C22);
    matrixSub(C22, 0, 0, C22[0].size() - 1, C22.size() - 1, P7, 0, 0, C22);
    add(P, C11[0].size(), C11.size(), P[0].size() - 1, P.size() - 1, C22);
}

/// Naive O(n*k*m) matrix product, used as the reference result for Strassen.
///
/// This ACCUMULATES: result[i][j] += sum_k left[i][k] * right[k][j]. To get
/// the plain product, `result` must be zero-initialized and sized
/// left.size() x (row length of right).
///
/// The i-k-j loop order walks right[k] and result[i] contiguously. Each cell
/// still sums its terms in increasing k, so results match the i-j-k order.
/// right[0] is never read, so an empty `right` (inner dimension 0) leaves
/// `result` untouched instead of reading out of bounds.
void matrixMultiply(std::vector<std::vector<int>>& left, std::vector<std::vector<int>>& right, std::vector<std::vector<int>>& result)
{
    for (size_t i = 0; i < left.size(); ++i)
    {
        std::vector<int>& out = result[i];
        for (size_t k = 0; k < right.size(); ++k)
        {
            const int a = left[i][k];
            const std::vector<int>& rowK = right[k];
            for (size_t j = 0; j < rowK.size(); ++j)
                out[j] += a * rowK[j];
        }
    }
}

/// Builds a rows x cols matrix filled with a shuffled permutation of
/// 1..rows*cols. A freshly default-constructed engine is used on every call,
/// so a given shape always produces the same matrix (reproducible test data).
static std::vector<std::vector<int>> makeShuffledMatrix(size_t rows, size_t cols)
{
    std::vector<int> values(rows * cols, 0);
    for (size_t i = 0; i < values.size(); ++i)
        values[i] = static_cast<int>(i + 1);
    std::shuffle(values.begin(), values.end(), std::default_random_engine());

    std::vector<std::vector<int>> matrix(rows, std::vector<int>(cols, 0));
    size_t next = 0;
    for (std::vector<int>& row : matrix)
        for (int& cell : row)
            cell = values[next++];
    return matrix;
}

/// Prints each row of `matrix` on its own line, every element followed by a space.
static void printMatrix(const std::vector<std::vector<int>>& matrix)
{
    for (const std::vector<int>& row : matrix)
    {
        for (int value : row)
            std::cout << value << " ";
        std::cout << '\n';
    }
}

/// Self-test: for every shape (NA x MA) * (MA x MB) with each dimension in
/// 1..10, multiplies two shuffled matrices with both Strassen and the naive
/// matrixMultiply, prints the operands and product, and exits with status -1
/// on the first mismatch. This range also covers the 3x4 * 4x5 example that
/// used to be kept here as a commented-out copy of the loop body.
int main()
{
    for (size_t NA = 1; NA <= 10; ++NA)
    {
        for (size_t MB = 1; MB <= 10; ++MB)
        {
            for (size_t MA = 1; MA <= 10; ++MA)
            {
                const size_t NB = MA;  // inner dimensions must agree
                std::vector<std::vector<int>> A = makeShuffledMatrix(NA, MA);
                std::vector<std::vector<int>> B = makeShuffledMatrix(NB, MB);

                std::cout << "相乘的两个矩阵为" << '\n';
                printMatrix(A);
                std::cout << '\n';
                printMatrix(B);

                std::vector<std::vector<int>> result1(NA, std::vector<int>(MB, 0));
                std::vector<std::vector<int>> result2(NA, std::vector<int>(MB, 0));
                // Strassen takes its operands by non-const reference, so give
                // the reference multiplication its own untouched copies.
                std::vector<std::vector<int>> refA = A;
                std::vector<std::vector<int>> refB = B;
                Strassen(A, B, result1);
                matrixMultiply(refA, refB, result2);

                if (result1 != result2)
                {
                    std::cout << "矩阵相乘结果错误!" << std::endl;
                    return -1;
                }
                std::cout << "矩阵相乘结果正确,乘积为" << '\n';
                printMatrix(result1);
            }
        }
    }
    return 0;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值