分治策略之矩阵乘法的几种实现

欢迎关注,定期更新算法问题

今天介绍一下分治算法的一个典型例子——矩阵乘法

如果以前了解过矩阵,应该知道矩阵的乘法公式C(m,n)=A(m,k)*B(k,n),在这里我们只讨论方阵,假设A是n*n阶,B也是n*n阶,那么要计算乘积需要进行n^2个元素,看《算法导论》给出的伪码:



上边这个过程进行三次循环,每次循环执行n步,故花费时间是O(n^3),下边是代码实现:

void Squre_Multiplay(int A[][n],int B[][n],int C[][n])
{
    int i,j;
    for(i=0;i<n;i++)
    {
        for(j=0;j<n;j++)
        {
            C[i][j]=0;
            for(int k=0;k<n;k++)
                C[i][j]=C[i][j]+A[i][k]*B[k][j];
        }
    }
}

当然我们学习了分治算法后可以考虑能否利用递归来降低时间复杂度,先来看一个简单的分治算法:

假设A,B,C都是n*n阶,n是2的幂,乘法公式C=A*B,我们可以将这三个矩阵分别化成4个(n/2)*(n/2)阶的矩阵,即


然后可以将乘法公式写成4个递归式:


有了递归式就可以设计算法了,思路即是进行分解,求解,合并这几个步骤,将大矩阵化为4个小矩阵直到足够小,然后进行简单的二阶矩阵的计算,以下是《算法导论》上的伪码:


这段伪码理解起来很简单,但是实现起来比较复杂,因为它隐藏了一个重要的细节,即如何划分矩阵的问题,我们常规做法可能是新建几个新的矩阵,然后在从原矩阵特定位置赋值过来,但这样实现起来复杂并且很容易出错,这里本人利用下标来进行划分和计算,代码看起来清晰:

void Squre_Multiplay_recursive(int A[][n],int B[][n],int C[][n],int A_flag[],int B_flag[])
{
    if(N==2)
    {
        C[A_flag[2]][B_flag[0]]=A[A_flag[2]][A_flag[0]]*B[B_flag[2]][B_flag[0]]+A[A_flag[2]][A_flag[1]]*B[B_flag[3]][B_flag[0]]+C[A_flag[2]][B_flag[0]];
        C[A_flag[2]][B_flag[1]]=A[A_flag[2]][A_flag[0]]*B[B_flag[2]][B_flag[1]]+A[A_flag[2]][A_flag[1]]*B[B_flag[3]][B_flag[1]]+C[A_flag[2]][B_flag[1]];
        C[A_flag[3]][B_flag[0]]=A[A_flag[3]][A_flag[0]]*B[B_flag[2]][B_flag[0]]+A[A_flag[3]][A_flag[1]]*B[B_flag[3]][B_flag[0]]+C[A_flag[3]][B_flag[0]];
        C[A_flag[3]][B_flag[1]]=A[A_flag[3]][A_flag[0]]*B[B_flag[2]][B_flag[1]]+A[A_flag[3]][A_flag[1]]*B[B_flag[3]][B_flag[1]]+C[A_flag[3]][B_flag[1]];
    }
       else
    {
       // int one[4],two[4],three[4],four[4];
       N=N/2;
        cout<<"N value:"<<N<<endl;
        int one_A[4],two_A[4],three_A[4],four_A[4];
        int one_B[4],two_B[4],three_B[4],four_B[4];
        //
        divide_array(A,A_flag,one_A,two_A,three_A,four_A);
        divide_array(B,B_flag,one_B,two_B,three_B,four_B);
        //
        Squre_Multiplay_recursive(A,B,C,one_A,one_B);
        Squre_Multiplay_recursive(A,B,C,two_A,three_B);
        //
        Squre_Multiplay_recursive(A,B,C,one_A,two_B);
        Squre_Multiplay_recursive(A,B,C,two_A,four_B);
        //
        Squre_Multiplay_recursive(A,B,C,three_A,one_B);
        Squre_Multiplay_recursive(A,B,C,four_A,three_B);
        //
        Squre_Multiplay_recursive(A,B,C,three_A,two_B);
        Squre_Multiplay_recursive(A,B,C,four_A,four_B);
    }
}
void divide_array(int Array[][n],int flag[],int one[],int two[],int three[],int four[])
{
    int left=flag[0];
    int right=flag[1];
    int top=flag[2];
    int bottom=flag[3];
        one[0]=left;
        one[1]=(left+right)/2;
         one[2]=top;
         one[3]=(top+bottom)/2;

        //
         two[0]=(left+right+1)/2;
         two[1]=right;
         two[2]=top;
         two[3]=(top+bottom)/2;

        //
         three[0]=left;
         three[1]=(left+right)/2;
         three[2]=(top+bottom+1)/2;
         three[3]=bottom;

        //
         four[0]=(left+right+1)/2;
         four[1]=right;
         four[2]=(top+bottom+1)/2;
         four[3]=bottom;
for(int i=0;i<4;i++)
    cout<<"划分边界:"<<one[i]<<" "<<two[i]<<" "<<three[i]<<" "<<four[i]<<endl;
}
第一个函数是递归函数,内部层次很明确,当矩阵规模大于2时,进行划分4个子矩阵,然后每个子矩阵递归调用,当规模为2阶时,直接利用公式计算。第二个函数是划分函数,为了容易理解和避免引入指针,我们用表示矩形的方式表示矩阵,即表示出来矩阵的左、右、上、下位置保存到位置数组中。
上述算法运行时间的递归式:


通过求解,可以看出T(n)=O(n^3),与一般算法比较,并没有任何提高,反而增加了递归带来的开销。

下一篇介绍Strassen算法。

此篇全部源码如下:

#include <iostream>
#define n 4
int N=n;
using namespace std;
void divide_array(int Array[][n],int flag[],int one[],int two[],int three[],int four[]);
//一般方法,时间复杂度为O(n^3),
void Squre_Multiplay(int A[][n],int B[][n],int C[][n]);
//普通分治算法,时间复杂度为O(n^2);
void Squre_Multiplay_recursive(int A[][n],int B[][n],int C[][n],int A_flag[],int B_flag[]);
//Stassen算法
int main()
{
    int A[4][4]={1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4};
    int B[4][4]={1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4};
    int C[4][4]={0};
    int A_flag[4]={0,3,0,3};
    int B_flag[4]={0,3,0,3};
   //Squre_Multiplay(A,B,C);
  Squre_Multiplay_recursive(A,B,C,A_flag,B_flag);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
            cout<<C[i][j]<<"\t";
        cout<<endl;
    }
    return 0;
}

void Squre_Multiplay(int A[][n],int B[][n],int C[][n])
{
    int i,j;
    for(i=0;i<n;i++)
    {
        for(j=0;j<n;j++)
        {
            C[i][j]=0;
            for(int k=0;k<n;k++)
                C[i][j]=C[i][j]+A[i][k]*B[k][j];
        }
    }
}

void Squre_Multiplay_recursive(int A[][n],int B[][n],int C[][n],int A_flag[],int B_flag[])
{
    if(N==2)
    {
        C[A_flag[2]][B_flag[0]]=A[A_flag[2]][A_flag[0]]*B[B_flag[2]][B_flag[0]]+A[A_flag[2]][A_flag[1]]*B[B_flag[3]][B_flag[0]]+C[A_flag[2]][B_flag[0]];
        C[A_flag[2]][B_flag[1]]=A[A_flag[2]][A_flag[0]]*B[B_flag[2]][B_flag[1]]+A[A_flag[2]][A_flag[1]]*B[B_flag[3]][B_flag[1]]+C[A_flag[2]][B_flag[1]];
        C[A_flag[3]][B_flag[0]]=A[A_flag[3]][A_flag[0]]*B[B_flag[2]][B_flag[0]]+A[A_flag[3]][A_flag[1]]*B[B_flag[3]][B_flag[0]]+C[A_flag[3]][B_flag[0]];
        C[A_flag[3]][B_flag[1]]=A[A_flag[3]][A_flag[0]]*B[B_flag[2]][B_flag[1]]+A[A_flag[3]][A_flag[1]]*B[B_flag[3]][B_flag[1]]+C[A_flag[3]][B_flag[1]];
    }
       else
    {
       // int one[4],two[4],three[4],four[4];
       N=N/2;
        cout<<"N value:"<<N<<endl;
        int one_A[4],two_A[4],three_A[4],four_A[4];
        int one_B[4],two_B[4],three_B[4],four_B[4];
        //
        divide_array(A,A_flag,one_A,two_A,three_A,four_A);
        divide_array(B,B_flag,one_B,two_B,three_B,four_B);
        //
        Squre_Multiplay_recursive(A,B,C,one_A,one_B);
        Squre_Multiplay_recursive(A,B,C,two_A,three_B);
        //
        Squre_Multiplay_recursive(A,B,C,one_A,two_B);
        Squre_Multiplay_recursive(A,B,C,two_A,four_B);
        //
        Squre_Multiplay_recursive(A,B,C,three_A,one_B);
        Squre_Multiplay_recursive(A,B,C,four_A,three_B);
        //
        Squre_Multiplay_recursive(A,B,C,three_A,two_B);
        Squre_Multiplay_recursive(A,B,C,four_A,four_B);
    }
}
void divide_array(int Array[][n],int flag[],int one[],int two[],int three[],int four[])
{
    int left=flag[0];
    int right=flag[1];
    int top=flag[2];
    int bottom=flag[3];
        one[0]=left;
        one[1]=(left+right)/2;
         one[2]=top;
         one[3]=(top+bottom)/2;

        //
         two[0]=(left+right+1)/2;
         two[1]=right;
         two[2]=top;
         two[3]=(top+bottom)/2;

        //
         three[0]=left;
         three[1]=(left+right)/2;
         three[2]=(top+bottom+1)/2;
         three[3]=bottom;

        //
         four[0]=(left+right+1)/2;
         four[1]=right;
         four[2]=(top+bottom+1)/2;
         four[3]=bottom;
for(int i=0;i<4;i++)
    cout<<"划分边界:"<<one[i]<<" "<<two[i]<<" "<<three[i]<<" "<<four[i]<<endl;
}


  • 7
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值