c++ 分而治之(施特拉森矩阵乘法)

给定两个大小分别为 nxn 的方阵 A 和 B,求它们的乘法矩阵。 
朴素方法:以下是两个矩阵相乘的简单方法。
 
void multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
}
 
// This code is contributed by noob2000.

上述方法的时间复杂度为O(N 3 )。 

分而治之 :
以下是两个方阵相乘的简单分而治之方法。 
1、将矩阵 A 和 B 分为 4 个大小为 N/2 x N/2 的子矩阵,如下图所示。 
2、递归计算以下值。 ae + bg、af + bh、ce + dg 和 cf + dh。 

执行:
#include <bits/stdc++.h>
using namespace std;
 
#define ROW_1 4
#define COL_1 4
 
#define ROW_2 4
#define COL_2 4
 
void print(string display, vector<vector<int> > matrix,
           int start_row, int start_column, int end_row,
           int end_column)
{
    cout << endl << display << " =>" << endl;
    for (int i = start_row; i <= end_row; i++) {
        for (int j = start_column; j <= end_column; j++) {
            cout << setw(10);
            cout << matrix[i][j];
        }
        cout << endl;
    }
    cout << endl;
    return;
}
 
void add_matrix(vector<vector<int> > matrix_A,
                vector<vector<int> > matrix_B,
                vector<vector<int> >& matrix_C,
                int split_index)
{
    for (auto i = 0; i < split_index; i++)
        for (auto j = 0; j < split_index; j++)
            matrix_C[i][j]
                = matrix_A[i][j] + matrix_B[i][j];
}
 
vector<vector<int> >
multiply_matrix(vector<vector<int> > matrix_A,
                vector<vector<int> > matrix_B)
{
    int col_1 = matrix_A[0].size();
    int row_1 = matrix_A.size();
    int col_2 = matrix_B[0].size();
    int row_2 = matrix_B.size();
 
    if (col_1 != row_2) {
        cout << "\nError: The number of columns in Matrix "
                "A  must be equal to the number of rows in "
                "Matrix B\n";
        return {};
    }
 
    vector<int> result_matrix_row(col_2, 0);
    vector<vector<int> > result_matrix(row_1,
                                       result_matrix_row);
 
    if (col_1 == 1)
        result_matrix[0][0]
            = matrix_A[0][0] * matrix_B[0][0];
    else {
        int split_index = col_1 / 2;
 
        vector<int> row_vector(split_index, 0);
        vector<vector<int> > result_matrix_00(split_index,
                                              row_vector);
        vector<vector<int> > result_matrix_01(split_index,
                                              row_vector);
        vector<vector<int> > result_matrix_10(split_index,
                                              row_vector);
        vector<vector<int> > result_matrix_11(split_index,
                                              row_vector);
 
        vector<vector<int> > a00(split_index, row_vector);
        vector<vector<int> > a01(split_index, row_vector);
        vector<vector<int> > a10(split_index, row_vector);
        vector<vector<int> > a11(split_index, row_vector);
        vector<vector<int> > b00(split_index, row_vector);
        vector<vector<int> > b01(split_index, row_vector);
        vector<vector<int> > b10(split_index, row_vector);
        vector<vector<int> > b11(split_index, row_vector);
 
        for (auto i = 0; i < split_index; i++)
            for (auto j = 0; j < split_index; j++) {
                a00[i][j] = matrix_A[i][j];
                a01[i][j] = matrix_A[i][j + split_index];
                a10[i][j] = matrix_A[split_index + i][j];
                a11[i][j] = matrix_A[i + split_index]
                                    [j + split_index];
                b00[i][j] = matrix_B[i][j];
                b01[i][j] = matrix_B[i][j + split_index];
                b10[i][j] = matrix_B[split_index + i][j];
                b11[i][j] = matrix_B[i + split_index]
                                    [j + split_index];
            }
 
        add_matrix(multiply_matrix(a00, b00),
                   multiply_matrix(a01, b10),
                   result_matrix_00, split_index);
        add_matrix(multiply_matrix(a00, b01),
                   multiply_matrix(a01, b11),
                   result_matrix_01, split_index);
        add_matrix(multiply_matrix(a10, b00),
                   multiply_matrix(a11, b10),
                   result_matrix_10, split_index);
        add_matrix(multiply_matrix(a10, b01),
                   multiply_matrix(a11, b11),
                   result_matrix_11, split_index);
 
        for (auto i = 0; i < split_index; i++)
            for (auto j = 0; j < split_index; j++) {
                result_matrix[i][j]
                    = result_matrix_00[i][j];
                result_matrix[i][j + split_index]
                    = result_matrix_01[i][j];
                result_matrix[split_index + i][j]
                    = result_matrix_10[i][j];
                result_matrix[i + split_index]
                             [j + split_index]
                    = result_matrix_11[i][j];
            }
 
        result_matrix_00.clear();
        result_matrix_01.clear();
        result_matrix_10.clear();
        result_matrix_11.clear();
        a00.clear();
        a01.clear();
        a10.clear();
        a11.clear();
        b00.clear();
        b01.clear();
        b10.clear();
        b11.clear();
    }
    return result_matrix;
}
 
int main()
{
    vector<vector<int> > matrix_A = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array A", matrix_A, 0, 0, ROW_1 - 1, COL_1 - 1);
 
    vector<vector<int> > matrix_B = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array B", matrix_B, 0, 0, ROW_2 - 1, COL_2 - 1);
 
    vector<vector<int> > result_matrix(
        multiply_matrix(matrix_A, matrix_B));
 
    print("Result Array", result_matrix, 0, 0, ROW_1 - 1,
          COL_2 - 1);
}
 
// Time Complexity: O(n^3)
// Code Contributed By: lucasletum

输出
数组A =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


数组 B =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


结果数组=>
         8 8 8 8
        16 16 16 16
        24 24 24 24
        16 16 16 16
        
在上述方法中,我们对大小为 N/2 x N/2 的矩阵进行 8 次乘法和 4 次加法。两个矩阵相加需要 O(N 2 ) 时间。所以时间复杂度可以写成 

T(N) = 8T(N/2) + O(N 2 )  

根据马斯特定理,上述方法的时间复杂度为 O(N 3 )
不幸的是,这与上面的简单方法相同。

简单的分而治之也导致O(N 3 ),有更好的方法吗? 

        在上面的分而治之的方法中,高时间复杂度的主要成分是8次递归调用。Strassen 方法的思想是将递归调用次数减少到 7 次。Strassen 方法与上述简单的分而治之方法类似,该方法也将矩阵划分为大小为 N/2 x N/2 的子矩阵:如上图所示,但在Strassen方法中,结果的四个子矩阵是使用以下公式计算的。

Strassen 方法的时间复杂度

两个矩阵的加法和减法需要 O(N 2 ) 时间。所以时间复杂度可以写成 

T(N) = 7T(N/2) + O(N 2 )

根据马斯特定理,上述方法的时间复杂度为
O(N Log7 ) 大约为 O(N 2.8074 )

一般来说,由于以下原因,施特拉森方法在实际应用中并不优选。 

1、Strassen 方法中使用的常数很高,对于典型应用,Naive 方法效果更好。 
2、对于稀疏矩阵,有专门为其设计的更好的方法。 
3、递归中的子矩阵占用额外的空间。 
4、由于计算机对非整数值的运算精度有限,Strassen 算法中累积的误差比 Naive 方法中更大。

执行:
#include <bits/stdc++.h>
using namespace std;
 
#define ROW_1 4
#define COL_1 4
 
#define ROW_2 4
#define COL_2 4
 
void print(string display, vector<vector<int> > matrix,
           int start_row, int start_column, int end_row,
           int end_column)
{
    cout << endl << display << " =>" << endl;
    for (int i = start_row; i <= end_row; i++) {
        for (int j = start_column; j <= end_column; j++) {
            cout << setw(10);
            cout << matrix[i][j];
        }
        cout << endl;
    }
    cout << endl;
    return;
}
 
vector<vector<int> >
add_matrix(vector<vector<int> > matrix_A,
           vector<vector<int> > matrix_B, int split_index,
           int multiplier = 1)
{
    for (auto i = 0; i < split_index; i++)
        for (auto j = 0; j < split_index; j++)
            matrix_A[i][j]
                = matrix_A[i][j]
                  + (multiplier * matrix_B[i][j]);
    return matrix_A;
}
 
vector<vector<int> >
multiply_matrix(vector<vector<int> > matrix_A,
                vector<vector<int> > matrix_B)
{
    int col_1 = matrix_A[0].size();
    int row_1 = matrix_A.size();
    int col_2 = matrix_B[0].size();
    int row_2 = matrix_B.size();
 
    if (col_1 != row_2) {
        cout << "\nError: The number of columns in Matrix "
                "A  must be equal to the number of rows in "
                "Matrix B\n";
        return {};
    }
 
    vector<int> result_matrix_row(col_2, 0);
    vector<vector<int> > result_matrix(row_1,
                                       result_matrix_row);
 
    if (col_1 == 1)
        result_matrix[0][0]
            = matrix_A[0][0] * matrix_B[0][0];
    else {
        int split_index = col_1 / 2;
 
        vector<int> row_vector(split_index, 0);
 
        vector<vector<int> > a00(split_index, row_vector);
        vector<vector<int> > a01(split_index, row_vector);
        vector<vector<int> > a10(split_index, row_vector);
        vector<vector<int> > a11(split_index, row_vector);
        vector<vector<int> > b00(split_index, row_vector);
        vector<vector<int> > b01(split_index, row_vector);
        vector<vector<int> > b10(split_index, row_vector);
        vector<vector<int> > b11(split_index, row_vector);
 
        for (auto i = 0; i < split_index; i++)
            for (auto j = 0; j < split_index; j++) {
                a00[i][j] = matrix_A[i][j];
                a01[i][j] = matrix_A[i][j + split_index];
                a10[i][j] = matrix_A[split_index + i][j];
                a11[i][j] = matrix_A[i + split_index]
                                    [j + split_index];
                b00[i][j] = matrix_B[i][j];
                b01[i][j] = matrix_B[i][j + split_index];
                b10[i][j] = matrix_B[split_index + i][j];
                b11[i][j] = matrix_B[i + split_index]
                                    [j + split_index];
            }
 
        vector<vector<int> > p(multiply_matrix(
            a00, add_matrix(b01, b11, split_index, -1)));
        vector<vector<int> > q(multiply_matrix(
            add_matrix(a00, a01, split_index), b11));
        vector<vector<int> > r(multiply_matrix(
            add_matrix(a10, a11, split_index), b00));
        vector<vector<int> > s(multiply_matrix(
            a11, add_matrix(b10, b00, split_index, -1)));
        vector<vector<int> > t(multiply_matrix(
            add_matrix(a00, a11, split_index),
            add_matrix(b00, b11, split_index)));
        vector<vector<int> > u(multiply_matrix(
            add_matrix(a01, a11, split_index, -1),
            add_matrix(b10, b11, split_index)));
        vector<vector<int> > v(multiply_matrix(
            add_matrix(a00, a10, split_index, -1),
            add_matrix(b00, b01, split_index)));
 
        vector<vector<int> > result_matrix_00(add_matrix(
            add_matrix(add_matrix(t, s, split_index), u,
                       split_index),
            q, split_index, -1));
        vector<vector<int> > result_matrix_01(
            add_matrix(p, q, split_index));
        vector<vector<int> > result_matrix_10(
            add_matrix(r, s, split_index));
        vector<vector<int> > result_matrix_11(add_matrix(
            add_matrix(add_matrix(t, p, split_index), r,
                       split_index, -1),
            v, split_index, -1));
 
        for (auto i = 0; i < split_index; i++)
            for (auto j = 0; j < split_index; j++) {
                result_matrix[i][j]
                    = result_matrix_00[i][j];
                result_matrix[i][j + split_index]
                    = result_matrix_01[i][j];
                result_matrix[split_index + i][j]
                    = result_matrix_10[i][j];
                result_matrix[i + split_index]
                             [j + split_index]
                    = result_matrix_11[i][j];
            }
 
        a00.clear();
        a01.clear();
        a10.clear();
        a11.clear();
        b00.clear();
        b01.clear();
        b10.clear();
        b11.clear();
        p.clear();
        q.clear();
        r.clear();
        s.clear();
        t.clear();
        u.clear();
        v.clear();
        result_matrix_00.clear();
        result_matrix_01.clear();
        result_matrix_10.clear();
        result_matrix_11.clear();
    }
    return result_matrix;
}
 
int main()
{
    vector<vector<int> > matrix_A = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array A", matrix_A, 0, 0, ROW_1 - 1, COL_1 - 1);
 
    vector<vector<int> > matrix_B = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array B", matrix_B, 0, 0, ROW_2 - 1, COL_2 - 1);
 
    vector<vector<int> > result_matrix(
        multiply_matrix(matrix_A, matrix_B));
 
    print("Result Array", result_matrix, 0, 0, ROW_1 - 1,
          COL_2 - 1);
}
 
// Time Complexity: T(N) = 7T(N/2) +  O(N^2) => O(N^Log7)
// which is approximately O(N^2.8074) Code Contributed By:
// lucasletum

输出
数组A =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


数组 B =>
         1 1 1 1
         2 2 2 2
         3 3 3 3
         2 2 2 2


结果数组=>
         8 8 8 8
        16 16 16 16
        24 24 24 24
        16 16 16 16 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值