C语言实现Strassen算法及分析

本文介绍了矩阵乘法的传统暴力计算法和分治算法,重点讲解了Strassen算法,一种通过分治思想将矩阵乘法的时间复杂度从O(n^3)降低到O(n^log7)的高效算法,包括具体步骤和代码实现。
摘要由CSDN通过智能技术生成

首先介绍一下矩阵A,B相乘的几种方法:

1.暴力计算法:利用三重循环进行计算,时间复杂度为O(n^3)。

void NormalMul(int** a,int** b,int **c,int n)//普通方法
{
    int i,j,k;
    for(i=0;i<n;i++){
        for(j=0;j<n;j++){
            int m=0;
            for(k=0;k<n;k++)m+=a[i][k]*b[k][j];
            c[i][j]=m;
        }
    }
}

2.分治算法:其时间复杂度也是O(n^3),所以不在详细说。

3.Strassen算法:其本质上也是使用了分治的思想,但是将时间复杂度优化到了O(n^log7),在大规模矩阵乘法上有极大优势。

现在重新定义7个新矩阵

M1=(A11+A22)*(B11+B22)

M2=(A21+A22)*B11

M3=A11*(B12-B22)

M4=A22*(B21-B11)

M5=(A11+A12)*B22

M6=(A21-A11)*(B11+B12)

M7=(A12-A22)*(B21+B22)

结果矩阵C可以组合上述矩阵,如下

C11=M1+M4-M5+M7

C12=M3+M5

C21=M2+M4

C22=M1-M2+M3+M6

这时候共用了7次乘法,18次加减法运算. 写出递推公式T(n)=7T(n/2)+Θ(n^2). 最终结果是O(n^log7)=O(n^2.807)

void Add(int **a,int **b,int **c,int size)//计算a+b->c
{
    int i,j;
    for(i=0;i<size;i++){
        for(j=0;j<size;j++){
            c[i][j]=a[i][j]+b[i][j];
        }
    }
}

void Minus(int **a,int **b,int **c,int size)//计算a-b->c
{
    int i,j;
    for(i=0;i<size;i++){
        for(j=0;j<size;j++){
            c[i][j]=a[i][j]-b[i][j];
        }
    }
}

void S_Mul(int **a,int **b,int **c,int size)//Strassen算法
{
    int half=size/2;
    if(size==512)NormalMul(a,b,c,size);//这里条件可以改变
    else{
        int **a11,**a12,**a21,**a22;
        int **b11,**b12,**b21,**b22;
        int **c11,**c12,**c21,**c22;
        int **aa,**bb;
        int **p1,**p2,**p3,**p4,**p5,**p6,**p7;
        //开辟空间
        a11=(int**)malloc(sizeof(int*)*half);
        a12=(int**)malloc(sizeof(int*)*half);
        a21=(int**)malloc(sizeof(int*)*half);
        a22=(int**)malloc(sizeof(int*)*half);
        b11=(int**)malloc(sizeof(int*)*half);
        b12=(int**)malloc(sizeof(int*)*half);
        b21=(int**)malloc(sizeof(int*)*half);
        b22=(int**)malloc(sizeof(int*)*half);
        c11=(int**)malloc(sizeof(int*)*half);
        c12=(int**)malloc(sizeof(int*)*half);
        c21=(int**)malloc(sizeof(int*)*half);
        c22=(int**)malloc(sizeof(int*)*half);
        aa=(int**)malloc(sizeof(int*)*half);
        bb=(int**)malloc(sizeof(int*)*half);
        p1=(int**)malloc(sizeof(int*)*half);
        p2=(int**)malloc(sizeof(int*)*half);
        p3=(int**)malloc(sizeof(int*)*half);
        p4=(int**)malloc(sizeof(int*)*half);
        p5=(int**)malloc(sizeof(int*)*half);
        p6=(int**)malloc(sizeof(int*)*half);
        p7=(int**)malloc(sizeof(int*)*half);

        int i,j;
        for(i=0;i<half;i++){
            a11[i]=(int*)malloc(sizeof(int)*half);
            a12[i]=(int*)malloc(sizeof(int)*half);
            a21[i]=(int*)malloc(sizeof(int)*half);
            a22[i]=(int*)malloc(sizeof(int)*half);
            b11[i]=(int*)malloc(sizeof(int)*half);
            b12[i]=(int*)malloc(sizeof(int)*half);
            b21[i]=(int*)malloc(sizeof(int)*half);
            b22[i]=(int*)malloc(sizeof(int)*half);
            c11[i]=(int*)malloc(sizeof(int)*half);
            c12[i]=(int*)malloc(sizeof(int)*half);
            c21[i]=(int*)malloc(sizeof(int)*half);
            c22[i]=(int*)malloc(sizeof(int)*half);
            aa[i]=(int*)malloc(sizeof(int)*half);
            bb[i]=(int*)malloc(sizeof(int)*half);
            p1[i]=(int*)malloc(sizeof(int)*half);
            p2[i]=(int*)malloc(sizeof(int)*half);
            p3[i]=(int*)malloc(sizeof(int)*half);
            p4[i]=(int*)malloc(sizeof(int)*half);
            p5[i]=(int*)malloc(sizeof(int)*half);
            p6[i]=(int*)malloc(sizeof(int)*half);
            p7[i]=(int*)malloc(sizeof(int)*half);
        }

        //分割数组
        for(i=0;i<half;i++){
            for(j=0;j<half;j++){
                a11[i][j]=a[i][j];
                a12[i][j]=a[i][j+half];
                a21[i][j]=a[i+half][j];
                a22[i][j]=a[i+half][j+half];

                b11[i][j]=b[i][j];
                b12[i][j]=b[i][j+half];
                b21[i][j]=b[i+half][j];
                b22[i][j]=b[i+half][j+half];
            }
        }

        //开始计算
        Minus(b12,b22,bb,half);
        S_Mul(a11,bb,p1,half);

        Add(a11,a12,aa,half);
        S_Mul(aa,b22,p2,half);

        Add(a21,a22,aa,half);
        S_Mul(aa,b11,p3,half);

        Minus(b21,b11,bb,half);
        S_Mul(a22,bb,p4,half);

        Add(a11,a22,aa,half);
        Add(b11,b22,bb,half);
        S_Mul(aa,bb,p5,half);

        Minus(a12,a22,aa,half);
        Add(b21,b22,bb,half);
        S_Mul(aa,bb,p6,half);

        Minus(a11,a21,aa,half);
        Add(b11,b12,bb,half);
        S_Mul(aa,bb,p7,half);

        //c11
        Add(p5,p4,c11,half);
        Minus(c11,p2,c11,half);
        Add(c11,p6,c11,half);
        //c12
        Add(p1,p2,c12,half);
        //c21
        Add(p3,p4,c21,half);
        //c22
        Add(p5,p1,c22,half);
        Minus(c22,p3,c22,half);
        Minus(c22,p7,c22,half);

        //合并
        for(i=0;i<half;i++){
            for(j=0;j<half;j++){
                c[i][j]=c11[i][j];
                c[i][j+half]=c12[i][j];
                c[i+half][j]=c21[i][j];
                c[i+half][j+half]=c22[i][j];
            }
        }
        //释放内存
        for (i = 0; i < half;i++) {
			free(a11[i]);
			free(a12[i]);
			free(a21[i]);
			free(a22[i]); 				
			free(b11[i]);
			free(b12[i]);
			free(b21[i]);				
			free(b22[i]);				
			free(c11[i]);
			free(c12[i]);
			free(c21[i]);				
			free(c22[i]);				
			free(p1[i]);
			free(p2[i]);
			free(p3[i]);
			free(p4[i]);				
			free(p5[i]);
			free(p6[i]);
			free(p7[i]);				
			free(aa[i]);
			free(bb[i]);			
		}
		free(a11);
		free(a12);
		free(a21);
		free(a22);				
		free(b11);
		free(b12);
		free(b21);
		free(b22);				
		free(c11);
		free(c12);
		free(c21);
		free(c22);				
		free(p1);
		free(p2);
		free(p3);
		free(p4);
		free(p5);				
		free(p6);
		free(p7);				
		free(aa);				
		free(bb);
    }
}

以上是Strassen算法的代码。

  • 15
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值