矩阵相乘的strassen算法_Strassen矩阵乘法

1 问题描述

求矩阵A,B相乘的结果C

2 分析

2.1 传统解法

直接根据矩阵乘法的定义来遍历计算。

c++语言代码:

void matrixMul(int** A,int** B,int** C,int m,int b,int n){
    for(int i=0;i<m;i++){
        for(int j=0;j<n;j++){
            *((int*)C+i*n+j)=0;
            for(int k=0;k<b;k++){
                *((int*)C+i*n+j)+=*((int*)A+i*b+k)*(*((int*)B+k*n+j));
            }
        }
    }
}
void test3(){
    int A[2][3]={{1,2,3},{1,2,3}};
    int B[3][2]={{1,2},{1,2},{1,1}};
    int C[2][2]={};

    matrixMul((int**)A,(int**)B,(int**)C,2,3,2);
    for(int i=0;i<2;i++){
        for(int j=0;j<2;j++){
            printf("%d ",C[i][j]);
        }
        printf("n");
    }
}

2.2 分治算法-Strassen

与整数乘法类似,可以将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵

ed3c5a2ee8f776f9c90eb989e86ae581.png

依次将矩阵的乘法按照上图拆分为最小单元的矩阵计算,即只有一个元素,然后再返回。

在计算的时候,如果直接分解然后计算,需要进行8次乘法运算。strassen算法可以将8次乘法降为7次乘法,从而降低时间复杂度:

fd8098e3eef72cd3b2b084d301717436.png

最终达到降低复杂度为

具体见代码

3 代码实现

#include <cmath>
#include <stdlib.h>
#include <cstdio>

#define SIZE 16

using namespace std;


/* For taking input from standard input(keyboard)*/
void ReadMatrix(double A[][SIZE],int N)
{
    int i,j;

    for(i=0; i<N; i++)
    {
        for(j=0; j<N; j++)
        {
            scanf("%lf", &A[i][j]);
        }
    }
}

/*For printing the matrix in standard output(console)*/
void WriteMatrix(double A[][SIZE], int N)
{
    int i, j;

    for(i=0; i<N; i++)
    {
        for(j=0; j<N; j++)
        {
            printf("%0.1lf t", A[i][j]);
        }
        printf("n");
    }
}

/*This function will add two square matrix*/
void MatrixAdd(double A[][SIZE], double B[][SIZE], double Result[][SIZE], int N)
{
    int i, j;

    for(i=0; i< N; i++)
    {
        for(j=0; j<N; j++)
        {
            Result[i][j] = A[i][j] + B[i][j];
        }
    }

}

/*This function will subtract one  square matrix from another*/
void MatrixSubtrac(double A[][SIZE], double B[][SIZE], double Result[][SIZE], int N)
{
    int i, j;

    for(i=0; i< N; i++)
    {
        for(j=0; j<N; j++)
        {
            Result[i][j] = A[i][j] - B[i][j];
        }
    }
}


/*This is the strassen algorithm. (Divide and Conqure)*/
void StrassenAlgorithm(double A[][SIZE], double B[][SIZE], double C[][SIZE], int N){
    // trivial case: when the matrice is 1 X 1:
    if(N == 1)
    {
        C[0][0] = A[0][0] * B[0][0];
        return;
    }

    // other cases are treated here:
    else{
        int Divide  = (int)(N/2);

        double A11[SIZE][SIZE], A12[SIZE][SIZE], A21[SIZE][SIZE], A22[SIZE][SIZE];
        double B11[SIZE][SIZE], B12[SIZE][SIZE], B21[SIZE][SIZE], B22[SIZE][SIZE];
        double C11[SIZE][SIZE], C12[SIZE][SIZE], C21[SIZE][SIZE], C22[SIZE][SIZE];
        double P1[SIZE][SIZE], P2[SIZE][SIZE], P3[SIZE][SIZE], P4[SIZE][SIZE], P5[SIZE][SIZE], P6[SIZE][SIZE], P7[SIZE][SIZE];
        double AResult[SIZE][SIZE], BResult[SIZE][SIZE];

        int i, j;

        //dividing the matrices in 4 sub-matrices:
        for (i = 0; i < Divide; i++)
        {
            for (j = 0; j < Divide; j++)
            {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j + Divide];
                A21[i][j] = A[i + Divide][j];
                A22[i][j] = A[i + Divide][j + Divide];

                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j + Divide];
                B21[i][j] = B[i + Divide][j];
                B22[i][j] = B[i + Divide][j + Divide];
            }
        }

        // Calculating p1 to p7:
        /*For details -- Introduction to Algorithms 3rd Edition by CLRS*/

        MatrixAdd(A11, A22, AResult, Divide);   // a11 + a22
        MatrixAdd(B11, B22, BResult, Divide);   // b11 + b22
        StrassenAlgorithm(AResult, BResult, P1, Divide);  // p1 = (a11+a22) * (b11+b22)

        MatrixAdd(A21, A22, AResult, Divide);   // a21 + a22
        StrassenAlgorithm(AResult, B11, P2, Divide); // p2 = (a21+a22) * (b11)

        MatrixSubtrac(B12, B22, BResult, Divide); // b12 - b22
        StrassenAlgorithm(A11, BResult, P3, Divide); // p3 = (a11) * (b12 - b22)

        MatrixSubtrac(B21, B11, BResult, Divide); // b21 - b11
        StrassenAlgorithm(A22, BResult, P4, Divide); // p4 = (a22) * (b21 - b11)

        MatrixAdd(A11, A12, AResult, Divide); // a11 + a12
        StrassenAlgorithm(AResult, B22, P5, Divide); // p5 = (a11+a12) * (b22)

        MatrixSubtrac(A21, A11, AResult, Divide); // a21 - a11
        MatrixAdd(B11, B12, BResult, Divide); // b11 + b12
        StrassenAlgorithm(AResult, BResult, P6, Divide); // p6 = (a21-a11) * (b11+b12)

        MatrixSubtrac(A12, A22, AResult, Divide); // a12 - a22
        MatrixAdd(B21, B22, BResult, Divide); // b21 + b22
        StrassenAlgorithm(AResult, BResult, P7, Divide); // p7 = (a12-a22) * (b21+b22)

        // calculating c21, c21, c11 e c22:

        MatrixAdd(P3, P5, C12, Divide); // c12 = p3 + p5
        MatrixAdd(P2, P4, C21, Divide); // c21 = p2 + p4

        MatrixAdd(P1, P4, AResult, Divide); // p1 + p4
        MatrixAdd(AResult, P7, BResult, Divide); // p1 + p4 + p7
        MatrixSubtrac(BResult, P5, C11, Divide); // c11 = p1 + p4 - p5 + p7

        MatrixAdd(P1, P3, AResult, Divide); // p1 + p3
        MatrixAdd(AResult, P6, BResult, Divide); // p1 + p3 + p6
        MatrixSubtrac(BResult, P2, C22, Divide); // c22 = p1 + p3 - p2 + p6


        // Grouping the results obtained in a single matrice:

        for (i = 0; i < Divide ; i++)
        {
            for (j = 0 ; j < Divide ; j++)
            {
                C[i][j] = C11[i][j];
                C[i][j + Divide] = C12[i][j];
                C[i + Divide][j] = C21[i][j];
                C[i + Divide][j + Divide] = C22[i][j];
            }
        }

    }

}

/*The main function*/
void test4(){
    double A[SIZE][SIZE], B[SIZE][SIZE], C[SIZE][SIZE];
    int i,j;
    int N,M,Count = 0;

    printf("What Is The Dimention: ");
    scanf("%d",&N);

    M = N;

    printf("Matrix A:n");
    ReadMatrix(A,M);
    printf("Matrix B:n");
    ReadMatrix(B,M);

    if(M > 1)
    {

        while(M>=2)
        {
            M/=2;
            Count++;
        }

        M = N;

        if(M != (pow(2.0,Count)))
        {
            N = pow(2.0,Count+1);

            for(i=0; i<N; i++)
            {
                for(j=0; j<N; j++)
                {
                    if((i>=M) || (j>=M))
                    {
                        A[i][j] = 0.0;
                        B[i][j] = 0.0;
                    }
                }
            }
        }
    }

    StrassenAlgorithm(A,B,C,N); // StrassenAlgorithm called here

    printf("Matrix A:nn");
    WriteMatrix(A,M);
    printf("Matrix B:nn");
    WriteMatrix(B,M);
    printf("The Product Of These Two Matrix:nn");
    WriteMatrix(C,M);
}

这份代码只能计算方阵,而且使用宏定义来预定矩阵最大尺寸,很丑陋。只是用来揭示Strassen的计算过程。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值