关于Strassen算法递归出口的讨论(C语言)

https://blog.csdn.net/2303_79007499/article/details/136608019?spm=1001.2014.3001.5501

上面是Strassen算法的C语言代码。

我们可以看到,对于递归出口,我们可以进行调节,最开始我设置为2的时候,竟然运行时间大大落后与普通方法,所以我增大了数值,运行时间得以下降。

void S_Mul(int **a,int **b,int **c,int size)
{
    int half=size/2;
    if(size<=512 || size % 2 != 0)NormalMul(a,b,c,size);
    else{
        int **a11,**a12,**a21,**a22;
        int **b11,**b12,**b21,**b22;
        int **c11,**c12,**c21,**c22;
        int **aa,**bb;
        int **p1,**p2,**p3,**p4,**p5,**p6,**p7;
        //开辟空间
        a11=(int**)malloc(sizeof(int*)*half);
        a12=(int**)malloc(sizeof(int*)*half);
        a21=(int**)malloc(sizeof(int*)*half);
        a22=(int**)malloc(sizeof(int*)*half);
        b11=(int**)malloc(sizeof(int*)*half);
        b12=(int**)malloc(sizeof(int*)*half);
        b21=(int**)malloc(sizeof(int*)*half);
        b22=(int**)malloc(sizeof(int*)*half);
        c11=(int**)malloc(sizeof(int*)*half);
        c12=(int**)malloc(sizeof(int*)*half);
        c21=(int**)malloc(sizeof(int*)*half);
        c22=(int**)malloc(sizeof(int*)*half);
        aa=(int**)malloc(sizeof(int*)*half);
        bb=(int**)malloc(sizeof(int*)*half);
        p1=(int**)malloc(sizeof(int*)*half);
        p2=(int**)malloc(sizeof(int*)*half);
        p3=(int**)malloc(sizeof(int*)*half);
        p4=(int**)malloc(sizeof(int*)*half);
        p5=(int**)malloc(sizeof(int*)*half);
        p6=(int**)malloc(sizeof(int*)*half);
        p7=(int**)malloc(sizeof(int*)*half);

        int i,j;
        for(i=0;i<half;i++){
            a11[i]=(int*)malloc(sizeof(int)*half);
            a12[i]=(int*)malloc(sizeof(int)*half);
            a21[i]=(int*)malloc(sizeof(int)*half);
            a22[i]=(int*)malloc(sizeof(int)*half);
            b11[i]=(int*)malloc(sizeof(int)*half);
            b12[i]=(int*)malloc(sizeof(int)*half);
            b21[i]=(int*)malloc(sizeof(int)*half);
            b22[i]=(int*)malloc(sizeof(int)*half);
            c11[i]=(int*)malloc(sizeof(int)*half);
            c12[i]=(int*)malloc(sizeof(int)*half);
            c21[i]=(int*)malloc(sizeof(int)*half);
            c22[i]=(int*)malloc(sizeof(int)*half);
            aa[i]=(int*)malloc(sizeof(int)*half);
            bb[i]=(int*)malloc(sizeof(int)*half);
            p1[i]=(int*)malloc(sizeof(int)*half);
            p2[i]=(int*)malloc(sizeof(int)*half);
            p3[i]=(int*)malloc(sizeof(int)*half);
            p4[i]=(int*)malloc(sizeof(int)*half);
            p5[i]=(int*)malloc(sizeof(int)*half);
            p6[i]=(int*)malloc(sizeof(int)*half);
            p7[i]=(int*)malloc(sizeof(int)*half);
        }

        //分割数组
        for(i=0;i<half;i++){
            for(j=0;j<half;j++){
                a11[i][j]=a[i][j];
                a12[i][j]=a[i][j+half];
                a21[i][j]=a[i+half][j];
                a22[i][j]=a[i+half][j+half];

                b11[i][j]=b[i][j];
                b12[i][j]=b[i][j+half];
                b21[i][j]=b[i+half][j];
                b22[i][j]=b[i+half][j+half];
            }
        }

        //开始计算
        Minus(b12,b22,bb,half);
        S_Mul(a11,bb,p1,half);

        Add(a11,a12,aa,half);
        S_Mul(aa,b22,p2,half);

        Add(a21,a22,aa,half);
        S_Mul(aa,b11,p3,half);

        Minus(b21,b11,bb,half);
        S_Mul(a22,bb,p4,half);

        Add(a11,a22,aa,half);
        Add(b11,b22,bb,half);
        S_Mul(aa,bb,p5,half);

        Minus(a12,a22,aa,half);
        Add(b21,b22,bb,half);
        S_Mul(aa,bb,p6,half);

        Minus(a11,a21,aa,half);
        Add(b11,b12,bb,half);
        S_Mul(aa,bb,p7,half);

        //c11
        Add(p5,p4,c11,half);
        Minus(c11,p2,c11,half);
        Add(c11,p6,c11,half);
        //c12
        Add(p1,p2,c12,half);
        //c21
        Add(p3,p4,c21,half);
        //c22
        Add(p5,p1,c22,half);
        Minus(c22,p3,c22,half);
        Minus(c22,p7,c22,half);

        //合并
        for(i=0;i<half;i++){
            for(j=0;j<half;j++){
                c[i][j]=c11[i][j];
                c[i][j+half]=c12[i][j];
                c[i+half][j]=c21[i][j];
                c[i+half][j+half]=c22[i][j];
            }
        }
        //释放内存
        for (i = 0; i < half;i++) {
			free(a11[i]);
			free(a12[i]);
			free(a21[i]);
			free(a22[i]); 				
			free(b11[i]);
			free(b12[i]);
			free(b21[i]);				
			free(b22[i]);				
			free(c11[i]);
			free(c12[i]);
			free(c21[i]);				
			free(c22[i]);				
			free(p1[i]);
			free(p2[i]);
			free(p3[i]);
			free(p4[i]);				
			free(p5[i]);
			free(p6[i]);
			free(p7[i]);				
			free(aa[i]);
			free(bb[i]);			
		}
		free(a11);
		free(a12);
		free(a21);
		free(a22);				
		free(b11);
		free(b12);
		free(b21);
		free(b22);				
		free(c11);
		free(c12);
		free(c21);
		free(c22);				
		free(p1);
		free(p2);
		free(p3);
		free(p4);
		free(p5);				
		free(p6);
		free(p7);				
		free(aa);				
		free(bb);
    }
}

上面是关于Strassen算法的有关于递归出口的源码,这里的递归出口我采取了size<=512的时候,采用普通方法,因为我发现当我取得小于512的时候,在2^11的规模及以下,Strassen算法的运行时间并不会小于普通的算法。下面我将比较不同递归出口下的运行时间(这里我将最小设置为32,因为再小运行时间就会很长):

我们知道Strassen算法的时间复杂度是ϴ(n^lg7),所以数据规模充分大的时候,其肯定是优于普通矩阵乘法的,但是数据规模较小的情况下,由于Strassen算法要多次进行矩阵赋值,所以并不优于普通方法。因此我尝试Strassen算法与普通方法相结合,在递归到矩阵规模足够小的情况下,使用普通方法作为递归出口。这个方法确实管用,运行时间大大缩减。

然后,我在想递归出口处的size<=n,时会停止递归而使用普通方法,那么n取什么值的时候,效益最大呢?首先,n肯定得是2的幂,因为递归过程中size是2的幂,因此只要考虑n是2的幂的情况。所以,我通过实验得到了上面的数据。那我们根据数据直接排除n<512的情况,对于n=512,数据规模(m)<=1024的情况下,两方法时间持平,在m>=2048时,Strassen算法体现出了较大优势。当n>=2048时,很容易分析出不如n=512的情况好,因为m<=n的时候,实际上用的都是普通方法。

那么根据分析n=512 or 1024时效率最好(n=512 or 1024 实际运行时间基本持平)

以上数据基于本人电脑得到,矩阵采用了随机赋值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值