第4章 分治策略——strassen算法

Strassen算法要求参与运算的两个矩阵都是方阵,并且维数是2的幂。如果矩阵不满足这两个条件,可以补上元素均为0的行与列使其满足条件。假设要计算C=A*B。

  1. 首先将A,B,C均分解为4个n/2*n/2的子矩阵。假设将A分解为A11,A12,A21,A22。 B分解为B11,B12, B21, B22。 对C做类似的操作。

  2. 创建10个n/2*n/2的矩阵 S1,S2,…, S10。

    • S1=B12-B22
    • S2=A11+A12
    • S3=A21+A22
    • S4=B21-B11
    • S5=A11+A22
    • S6=B11+B22
    • S7=A12-A22
    • S8=B21+B22
    • S9=A11-A21
    • S10=B11+B12
  3. 递归地计算7次n/2*n/2矩阵的乘法

    • P1=A11*S1
    • P2=S2*B22
    • P3=S3*B11
    • P4=A22*S4
    • P5=S5*S6
    • P6=S7*S8
    • P7=S9*S10
  4. 对步骤3创建的7个矩阵进行加减运算,计算出C的4个子矩阵

    • C11=P5+P4-P2+P6
    • C12=P1+P2
    • C21=P3+P4
    • C22=P5+P1-P3-P7

    代码如下:

#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Copy the window of A covering rows [lowRow, highRow) and columns
// [lowColumn, highColumn) into subMatr.
//
// On success subMatr has exactly (highRow-lowRow) rows of
// (highColumn-lowColumn) elements; whatever it held before is discarded.
//
// Throws std::runtime_error when a bound is negative, when a high bound is
// below its low bound, or when the window reaches past the rows/columns
// actually present in A. subMatr is left untouched in that case.
template<class Type>
void subMatrices(int lowRow,int highRow,int lowColumn,int highColumn,const std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& subMatr)
{
        if(lowRow<0||lowColumn<0||highRow<lowRow||highColumn<lowColumn)
                throw std::runtime_error("the requested sub matrix range is invalid!");
        if(static_cast<std::size_t>(highRow)>A.size())
                throw std::runtime_error("the requested sub matrix range exceeds the source matrix!");

        const std::size_t rowCount=static_cast<std::size_t>(highRow-lowRow);

        // Assemble the result on the side and swap it in at the end, so a
        // failure half-way through cannot leave subMatr partially filled.
        std::vector< std::vector<Type> > window(rowCount);
        for(std::size_t i=0;i!=rowCount;++i){
                const std::vector<Type>& sourceRow=A[lowRow+i];
                // Rows are checked one by one because nothing forces A to be rectangular.
                if(static_cast<std::size_t>(highColumn)>sourceRow.size())
                        throw std::runtime_error("the requested sub matrix range exceeds the source matrix!");
                window[i].assign(sourceRow.begin()+lowColumn,sourceRow.begin()+highColumn);
        }

        subMatr.swap(window);
}

// Element-wise difference: C = A - B.
//
// A and B must be non-empty and have the same shape, with every row as wide
// as row 0 of A. C is resized to that shape and every element is overwritten.
// C may be the same object as A or B.
//
// Throws std::runtime_error if either matrix has no rows or if the shapes
// differ. All checks run before C is modified.
template<class Type>
void matrixSubtract(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
        if(A.empty()||B.empty())
                throw std::runtime_error("one of the matrices is empty!");
        if(A.size()!=B.size())
                throw std::runtime_error("the dimensions of the two matrices are not the same!");

        const std::size_t rowCount=A.size();
        const std::size_t columnCount=A[0].size();

        // Comparing only the first rows is not enough: a shorter row further
        // down would be indexed past its end in the loop below.
        for(std::size_t i=0;i!=rowCount;++i)
                if(A[i].size()!=columnCount||B[i].size()!=columnCount)
                        throw std::runtime_error("the dimensions of the two matrices are not the same!");

        C.resize(rowCount);
        for(std::size_t i=0;i!=rowCount;++i){
                C[i].resize(columnCount,Type());
                for(std::size_t j=0;j!=columnCount;++j)
                        C[i][j]=A[i][j]-B[i][j];
        }
}

// Element-wise sum: C = A + B.
//
// A and B must be non-empty and have the same shape, with every row as wide
// as row 0 of A. C is resized to that shape and every element is overwritten.
// C may be the same object as A or B.
//
// Throws std::runtime_error if either matrix has no rows or if the shapes
// differ. All checks run before C is modified.
template<class Type>
void matrixAdd(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
        if(A.empty()||B.empty())
                throw std::runtime_error("one of the matrices is empty!");
        if(A.size()!=B.size())
                throw std::runtime_error("the dimensions of the two matrices are not the same!");

        const std::size_t rowCount=A.size();
        const std::size_t columnCount=A[0].size();

        // Comparing only the first rows is not enough: a shorter row further
        // down would be indexed past its end in the loop below.
        for(std::size_t i=0;i!=rowCount;++i)
                if(A[i].size()!=columnCount||B[i].size()!=columnCount)
                        throw std::runtime_error("the dimensions of the two matrices are not the same!");

        C.resize(rowCount);
        for(std::size_t i=0;i!=rowCount;++i){
                C[i].resize(columnCount,Type());
                for(std::size_t j=0;j!=columnCount;++j)
                        C[i][j]=A[i][j]+B[i][j];
        }
}

// Pad A and B in place to the same square size n x n, where n is the smallest
// power of two that is >= every row count and every row length of both
// matrices. Existing elements keep their positions; new cells hold Type().
//
// Empty input is allowed and yields 1 x 1 matrices (2^0 is the smallest
// power of two).
template<class Type>
void dimensionWithPowerOf2(std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& B)
{
        // Take the largest extent over both matrices. B's row count and every
        // individual row are included so that padding can never cut data off.
        std::size_t maxDimension=A.size();
        if(B.size()>maxDimension)
                maxDimension=B.size();
        for(std::size_t i=0;i!=A.size();++i)
                if(A[i].size()>maxDimension)
                        maxDimension=A[i].size();
        for(std::size_t i=0;i!=B.size();++i)
                if(B[i].size()>maxDimension)
                        maxDimension=B[i].size();

        // Integer doubling instead of pow(): exact, no floating-point rounding.
        std::size_t paddedDimension=1;
        while(paddedDimension<maxDimension)
                paddedDimension*=2;

        A.resize(paddedDimension);
        for(std::size_t i=0;i!=paddedDimension;++i)
                A[i].resize(paddedDimension,Type());

        B.resize(paddedDimension);
        for(std::size_t i=0;i!=paddedDimension;++i)
                B[i].resize(paddedDimension,Type());
}

    // Recursive core of Strassen's algorithm: C = A * B.
    //
    // Assumes A and B are square, equally sized, and that the size is a power
    // of two (strassen() pads its inputs before calling). Nothing here checks
    // that; for other shapes the quadrants stop matching and the add/subtract
    // helpers throw std::runtime_error.
    //
    // C is resized to A.size() x A.size() before anything else happens.
    //
    // Matrices of dimension 128 or less are passed to directMultiply().
    // NOTE(review): directMultiply is not defined in this part of the file;
    // it is presumably the classical triple-loop product and must be declared
    // before this template - confirm.
    template<class Type>
void strassenRecur(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B,std::vector< std::vector<Type> >& C)
{
        typedef std::vector< std::vector<Type> > Matrix;

        const std::size_t dimension=A.size();

        C.resize(dimension);
        for(std::size_t i=0;i!=dimension;++i)
                C[i].resize(dimension,Type());

        // Base case. This has to be "<=": with "==" any input whose padded
        // size is below the cut-off never matches, keeps halving down to
        // 1 x 1, splits that into empty quadrants and ends in an exception.
        // Stopping at 128 rather than 1 also saves a great deal of time.
        const std::size_t directMultiplyThreshold=128;
        if(dimension<=directMultiplyThreshold){
                directMultiply(A,B,C);
                return;
        }

        const int sizeA=static_cast<int>(dimension);
        const int sizeB=static_cast<int>(B.size());
        const int mid=sizeA/2;

        // split A into four quadrants
        Matrix A11;
        subMatrices(0,mid,0,mid,A,A11);
        Matrix A12;
        subMatrices(0,mid,mid,sizeA,A,A12);
        Matrix A21;
        subMatrices(mid,sizeA,0,mid,A,A21);
        Matrix A22;
        subMatrices(mid,sizeA,mid,sizeA,A,A22);

        // split B into four quadrants
        Matrix B11;
        subMatrices(0,mid,0,mid,B,B11);
        Matrix B12;
        subMatrices(0,mid,mid,sizeB,B,B12);
        Matrix B21;
        subMatrices(mid,sizeB,0,mid,B,B21);
        Matrix B22;
        subMatrices(mid,sizeB,mid,sizeB,B,B22);

        // P1 = A11 * S1, S1 = B12 - B22
        Matrix s1;
        matrixSubtract(B12,B22,s1);
        Matrix p1;
        strassenRecur(A11,s1,p1);

        // P2 = S2 * B22, S2 = A11 + A12
        Matrix s2;
        matrixAdd(A11,A12,s2);
        Matrix p2;
        strassenRecur(s2,B22,p2);

        // P3 = S3 * B11, S3 = A21 + A22
        Matrix s3;
        matrixAdd(A21,A22,s3);
        Matrix p3;
        strassenRecur(s3,B11,p3);

        // P4 = A22 * S4, S4 = B21 - B11
        Matrix s4;
        matrixSubtract(B21,B11,s4);
        Matrix p4;
        strassenRecur(A22,s4,p4);

        // P5 = S5 * S6, S5 = A11 + A22, S6 = B11 + B22
        Matrix s5;
        matrixAdd(A11,A22,s5);
        Matrix s6;
        matrixAdd(B11,B22,s6);
        Matrix p5;
        strassenRecur(s5,s6,p5);

        // P6 = S7 * S8, S7 = A12 - A22, S8 = B21 + B22
        Matrix s7;
        matrixSubtract(A12,A22,s7);
        Matrix s8;
        matrixAdd(B21,B22,s8);
        Matrix p6;
        strassenRecur(s7,s8,p6);

        // P7 = S9 * S10, S9 = A11 - A21, S10 = B11 + B12
        Matrix s9;
        matrixSubtract(A11,A21,s9);
        Matrix s10;
        matrixAdd(B11,B12,s10);
        Matrix p7;
        strassenRecur(s9,s10,p7);

        // Scratch matrices reused while combining the seven products.
        Matrix tmp1;
        Matrix tmp2;

        // C11 = P5 + P4 - P2 + P6  (top-left quadrant)
        matrixAdd(p5,p4,tmp1);
        matrixSubtract(tmp1,p2,tmp2);
        matrixAdd(tmp2,p6,tmp1);
        for(int i=0;i!=mid;++i)
                for(int j=0;j!=mid;++j)
                        C[i][j]=tmp1[i][j];

        // C12 = P1 + P2  (top-right quadrant)
        matrixAdd(p1,p2,tmp1);
        for(int i=0;i!=mid;++i)
                for(int j=0;j!=mid;++j)
                        C[i][j+mid]=tmp1[i][j];

        // C21 = P3 + P4  (bottom-left quadrant)
        matrixAdd(p3,p4,tmp1);
        for(int i=0;i!=mid;++i)
                for(int j=0;j!=mid;++j)
                        C[i+mid][j]=tmp1[i][j];

        // C22 = P5 + P1 - P3 - P7  (bottom-right quadrant)
        matrixAdd(p5,p1,tmp1);
        matrixSubtract(tmp1,p3,tmp2);
        matrixSubtract(tmp2,p7,tmp1);
        for(int i=0;i!=mid;++i)
                for(int j=0;j!=mid;++j)
                        C[i+mid][j+mid]=tmp1[i][j];
}

// Compute C = A * B with Strassen's algorithm.
//
// A is rows x inner and B is inner x columns; C is resized to rows x columns
// and fully overwritten. A and B are left unchanged: the zero-padding to a
// power-of-two square is applied to private copies. C may be the same object
// as A or B, because it is only written after the product has been computed.
//
// Throws std::runtime_error if either matrix has no rows, or if the width of
// A's first row differs from B's row count.
template<class Type>
void strassen( std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
        if(A.empty()||B.empty())
                throw std::runtime_error("one of the two matrices is empty!");
        if(A[0].size()!=B.size())
                throw std::runtime_error("The column number of the first matrice is unequal to the row number of the second matrice!");

        // Shape of the real product, recorded before any padding happens.
        const std::size_t resultRows=A.size();
        const std::size_t resultColumns=B[0].size();

        // Pad copies instead of the arguments, so the caller's matrices do
        // not silently grow into zero-padded squares.
        std::vector< std::vector<Type> > paddedA(A);
        std::vector< std::vector<Type> > paddedB(B);
        dimensionWithPowerOf2(paddedA,paddedB);

        // Product of the padded squares; the padding only adds zero rows and
        // columns, so the top-left resultRows x resultColumns corner is A*B.
        std::vector< std::vector<Type> > multiplyResult;
        strassenRecur(paddedA,paddedB,multiplyResult);

        // Copy that corner into C.
        C.resize(resultRows);
        for(std::size_t i=0;i!=resultRows;++i)
                C[i].assign(multiplyResult[i].begin(),multiplyResult[i].begin()+resultColumns);
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值