Strassen’s 矩阵乘法—分治法实现

内容会持续更新,有错误的地方欢迎指正,谢谢!

前言:博主最近正在学习《算法》这门专业课程,这是该课程的第二次上机题目,我把自己的解题方法分享给大家,欢迎讨论!

题目:
1.比较数学定义的矩阵乘法算法和Strassen’s 矩阵乘法算法的效率;
2.自主生成两个16*16的矩阵,输出Strassen’s 矩阵乘法算法结果。

数学定义的矩阵乘法算法:利用三个for循环来解决,时间复杂度为O(n^3)。

数学定义的矩阵乘法算法的核心代码如下:

//公理:两个矩阵相乘A*B,A的列数必等于B的行数。
    int a[2][3] = {1, 1, 1, 1, 1, 1};  
    int b[3][1] = {1, 1, 1};  
    for (int i = 0; i < 2; ++i)  
    {  
        for (int j = 0; j < 1; ++j)  
        {  
            c[i][j] = 0;  
            for (int k = 0; k < 3; ++k)  
                c[i][j] += a[i][k] * b[k][j];  
        }  
    }  

一般算法需要八次乘法:

这里写图片描述

试试Strassen’s 矩阵乘法算法:

这里写图片描述

我们可以推出:

这里写图片描述

上面只有7次乘法和多次加减法,Strassen’s 矩阵乘法算法将其变成7次乘法。大家都知道乘法比加减法消耗更多的性能!所以,该算法能将时间复杂度降低到O( n^lg7 ) = O( n^2.81 )。

代码实现如下:(其中N必须为2的幂,这里N=16)

#include <iostream>
using namespace std;
#define N 16

//矩阵相加
void Plus(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
    int i, j;
    for (i = 0; i < N / 2; i++)
    {
        for (j = 0; j < N / 2; j++)
        {
            a[i][j] = b[i][j] + c[i][j];
        }
    }
}

//矩阵相减
void Minus(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
    int i, j;
    for (i = 0; i < N / 2; i++)
    {
        for (j = 0; j < N / 2; j++)
        {
            a[i][j] = b[i][j] - c[i][j];
        }
    }
}

//矩阵相乘
void Multiply(int a[N / 2][N / 2], int b[N / 2][N / 2], int c[N / 2][N / 2])
{
    int i, j, k;
    for (i = 0; i < N / 2; i++)
    {
        for (j = 0; j < N / 2; j++)
        {
            a[i][j] = 0;
            for (k = 0; k < N / 2; k++)
            {
                a[i][j] += b[i][k] * c[k][j];
            }
        }
    }
}

int main()
{
    int i, j, k;
    int m1[N][N];
    int m2[N][N];
    for (i = 0; i < N; ++i)//初始化要相乘的这两个16*16的矩阵
    {
        for (j = 0; j < N; ++j)
        {
            m1[i][j] = 1;
            m2[i][j] = 1;
        }
    }

    int I[N / 2][N / 2], J[N / 2][N / 2], K[N / 2][N / 2], L[N / 2][N / 2];
    int A[N / 2][N / 2], B[N / 2][N / 2], C[N / 2][N / 2], D[N / 2][N / 2];
    int E[N / 2][N / 2], F[N / 2][N / 2], G[N / 2][N / 2], H[N / 2][N / 2];
    int S1[N / 2][N / 2], S2[N / 2][N / 2], S3[N / 2][N / 2], S4[N / 2][N / 2];
    int S5[N / 2][N / 2], S6[N / 2][N / 2], S7[N / 2][N / 2];
    int t1[N / 2][N / 2], t2[N / 2][N / 2];

    //将原矩阵m1、m2拆分为A B C D E F G H矩阵
    for (i = 0; i < N / 2; i++)
    {
        for (j = 0; j < N / 2; j++)
        {
            A[i][j] = m1[i][j];
            B[i][j] = m1[i][j + N / 2];
            C[i][j] = m1[i + N / 2][j];
            D[i][j] = m1[i + N / 2][j + N / 2];

            E[i][j] = m2[i][j];
            F[i][j] = m2[i][j + N / 2];
            G[i][j] = m2[i + N / 2][j];
            H[i][j] = m2[i + N / 2][j + N / 2];
        }
    }

    //S1  
    Minus(I, F, H);
    Multiply(S1, A, I);

    //S2  
    Plus(I, A, B);
    Multiply(S2, I, H);

    //S3  
    Plus(I, C, D);
    Multiply(S3, I, E);

    //S4  
    Minus(I, G, E);
    Multiply(S4, D, I);

    //S5  
    Plus(I, A, D);
    Plus(J, E, F);
    Multiply(S5, I, J);

    //S6  
    Minus(I, B, D);
    Plus(J, G, H);
    Multiply(S6, I, J);

    //S7  
    Minus(I, A, C);
    Plus(J, E, F);
    Multiply(S7, I, J);

    //计算I J K L矩阵
    //I = S5 + S4 - S2 + S6  
    Plus(t1, S5, S4);
    Minus(t2, t1, S2);
    Plus(I, t2, S6);

    //J = S1 + S2  
    Plus(J, S1, S2);

    //K = S3 + S4  
    Plus(K, S3, S4);

    //L = S5 + S1 - S3 - S7 = S5 + S1 - ( S3 + S7 )  
    Plus(t1, S5, S1);
    Plus(t2, S3, S7);
    Minus(L, t1, t2);

    //将得到的I J K L矩阵合并到最终结果result矩阵中
    int result[N][N] = { 0 };
    for (int i = 0; i < N / 2; i++)
    {
        for (int j = 0; j < N / 2; j++)
        {
            result[i][j] = I[i][j];
            result[i][j + N / 2] = J[i][j];
            result[i + N / 2][j] = K[i][j];
            result[i + N / 2][j + N / 2] = L[i][j];
        }
    }

    //输出最终的矩阵
    for (i = 0; i < N; ++i)
    {
        k = 0;
        for (j = 0; j < N; ++j)
        {
            cout << result[i][j] << "  ";
            ++k;
            if (k == N)
                cout << endl;
        }
    }
    cout << endl;

    getchar();
    return 0;
}

备注:由于博主时间问题,本代码并未实现递归,也就是并未利用分治法拆分到最小单元再计算再合并,只是阐述了分治法解决该问题的思路,若要实现完整版代码,我指明方法:
新建一个递归函数,需将main()里的部分代码移到递归函数里,并需修改递归函数里的所有二维数组的定义,例如:

int MatrixA[N / 2][N / 2];
int MatrixB[N / 2][N / 2];
int MatrixC[N / 2][N / 2];

应该被修改为如下形式:

//n为递归函数传入的参数
int** MatrixA = new int*[n];
int** MatrixB = new int*[n];
int** MatrixC = new int*[n];
for (int i = 0; i < n; i++)
{
    MatrixA[i] = new int[n];
    MatrixB[i] = new int[n];
    MatrixC[i] = new int[n];
}

用完new的二维数组之后还要记得释放内存,不然,在递归中,很容易产生内存泄漏:

for (int i = 0; i < n; i++)
{
    delete[] A[i];
    delete[] B[i];
}
delete[] A;
delete[] B;

递归函数的参数有n,MatrixA,MatrixB,MatrixC
n用于传递矩阵维数。
MatrixA矩阵就是上方代码的m1矩阵。该题是求m1乘以m2矩阵,你就知道m1是什么了。
MatrixB矩阵就是上方代码的m2矩阵。该题是求m1乘以m2矩阵,你就知道m2是什么了。
MatrixC矩阵用于记录结果,最后输出MatrixC即是最终结果。

分治法实现的完整代码,能输出最终结果和每一次递归的S1~S7:
http://download.csdn.net/download/billcyj/10157466

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值