Implementation of Strassen’s Algorithm for Matrix Multiplication

Strassen’s algorithm is not the most efficient algorithm for matrix multiplication, but it was the first algorithm that was asymptotically faster than the naive O(n³) algorithm. There is a very good explanation and implementation of Strassen’s algorithm on Wikipedia.

However, that implementation cannot be used directly: it sets the base case of the divide-and-conquer recursion to a 1×1 matrix, so the overhead of the recursion dominates the running time. If the base case is set to a 2×2 matrix instead (meaning that 2×2 and 1×1 matrices are multiplied with the naive algorithm), Strassen’s algorithm becomes more efficient than the naive algorithm for matrices larger than 512×512.

When the base case is set to a 2×2 matrix, Strassen’s algorithm outperforms the naive algorithm for matrices larger than 512×512.

When the base case is set to a 6×6 matrix, Strassen’s algorithm outperforms the naive algorithm for matrices larger than 128×128.

/*------------------------------------------------------------------------------*/
// 	matrix_mult.cc -- Implementation of matrix multiplication with
// 			  Strassen's algorithm. 
//
// Compile this file with g++:
//     g++ -Wall -o matrix_mult matrix_mult.cc                                                
 
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <ctype.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <cmath>
#include <cstring>

using namespace std;

// Allocates an n x n matrix as an array of n row pointers. Each row is a
// separate heap array of n doubles, zero-filled with memset.
// The caller owns the result and releases it with free_matrix(mat, n).
inline double** allocate_matrix(int n)
{
    double** rows = new double*[n];
    int r = 0;
    while (r < n)
    {
        rows[r] = new double[n];
        memset(rows[r], 0, n * sizeof(double));
        ++r;
    }
    return rows;
}

/*------------------------------------------------------------------------------*/
// Releases a matrix obtained from allocate_matrix(n). Every row is deleted
// first, then the array of row pointers. n must be the dimension that was
// passed to allocate_matrix. A null M is accepted and ignored.
//
// M is passed by value, so this function cannot reset the caller's pointer.
// The caller must not use M after this call.
inline void free_matrix(double **M, int n)
{
    if (M == NULL)
        return;

    for (int i = 0; i < n; i++)
        delete [] M[i];

    delete [] M;
}

/*------------------------------------------------------------------------------*/
// Element-wise sum of two tam x tam matrices: result = a + b.
// result may be the same matrix as a or b: each element is read before it
// is written, and no other element depends on it.
inline void sum(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *lhs = a[row];
        const double *rhs = b[row];
        double *out = result[row];
        for (int col = 0; col < tam; ++col)
            out[col] = lhs[col] + rhs[col];
    }
}
 
/*------------------------------------------------------------------------------*/
// Element-wise difference of two tam x tam matrices: result = a - b.
inline void subtract(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *minuend = a[row];
        const double *subtrahend = b[row];
        double *difference = result[row];
        for (int col = 0; col < tam; ++col) {
            difference[col] = minuend[col] - subtrahend[col];
        }
    }
}

/*------------------------------------------------------------------------------*/
// Classic O(n^3) triple-loop multiplication of two n x n matrices: C = A * B.
//
// Each C[i][j] is summed in a local accumulator and then assigned. The result
// therefore does not depend on the previous contents of C, so C need not be
// zero-filled first. C must not alias A or B.
void naive(double** A, double** B, double** C, int n)
{
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            double acc = 0.0;
            for (int k = 0; k < n; k++)
                acc += A[i][k] * B[k][j];
            C[i][j] = acc;
        }
    }
}

/*------------------------------------------------------------------------------*/
// Strassen's method: computes c = a * b for tam x tam matrices.
//
// Each operand is split into four quadrants. The product is assembled from
// seven recursive sub-products p1..p7 instead of the eight needed by plain
// block multiplication. Small problems and odd sizes go to naive() instead.
// Odd sizes cannot be split into four equal quadrants.
//
// Every element of c is overwritten. c must not alias a or b.
void strassen(double **a, double **b, double **c, int tam)
{
    // At or below this size the recursion overhead (many temporary matrix
    // allocations per level) outweighs the saved multiplication.
    const int kNaiveThreshold = 4;

    // Odd sizes would otherwise lose their last row and column when halved.
    if (tam <= kNaiveThreshold || tam % 2 != 0)
    {
        // Zero c first so the result does not depend on whether naive()
        // assigns to c or adds into it.
        for (int i = 0; i < tam; i++)
            for (int j = 0; j < tam; j++)
                c[i][j] = 0.0;
        naive(a, b, c, tam);
        return;
    }

    const int half = tam / 2;

    // Quadrants of the two operands.
    double **a11 = allocate_matrix(half);
    double **a12 = allocate_matrix(half);
    double **a21 = allocate_matrix(half);
    double **a22 = allocate_matrix(half);

    double **b11 = allocate_matrix(half);
    double **b12 = allocate_matrix(half);
    double **b21 = allocate_matrix(half);
    double **b22 = allocate_matrix(half);

    // The seven Strassen products.
    double **p1 = allocate_matrix(half);
    double **p2 = allocate_matrix(half);
    double **p3 = allocate_matrix(half);
    double **p4 = allocate_matrix(half);
    double **p5 = allocate_matrix(half);
    double **p6 = allocate_matrix(half);
    double **p7 = allocate_matrix(half);

    // Scratch space for the operands of each product.
    double **aResult = allocate_matrix(half);
    double **bResult = allocate_matrix(half);

    // Split a and b into their four quadrants.
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            a11[i][j] = a[i][j];
            a12[i][j] = a[i][j + half];
            a21[i][j] = a[i + half][j];
            a22[i][j] = a[i + half][j + half];

            b11[i][j] = b[i][j];
            b12[i][j] = b[i][j + half];
            b21[i][j] = b[i + half][j];
            b22[i][j] = b[i + half][j + half];
        }
    }

    // p1 = (a11 + a22) * (b11 + b22)
    sum(a11, a22, aResult, half);
    sum(b11, b22, bResult, half);
    strassen(aResult, bResult, p1, half);

    // p2 = (a21 + a22) * b11
    sum(a21, a22, aResult, half);
    strassen(aResult, b11, p2, half);

    // p3 = a11 * (b12 - b22)
    subtract(b12, b22, bResult, half);
    strassen(a11, bResult, p3, half);

    // p4 = a22 * (b21 - b11)
    subtract(b21, b11, bResult, half);
    strassen(a22, bResult, p4, half);

    // p5 = (a11 + a12) * b22
    sum(a11, a12, aResult, half);
    strassen(aResult, b22, p5, half);

    // p6 = (a21 - a11) * (b11 + b12)
    subtract(a21, a11, aResult, half);
    sum(b11, b12, bResult, half);
    strassen(aResult, bResult, p6, half);

    // p7 = (a12 - a22) * (b21 + b22)
    subtract(a12, a22, aResult, half);
    sum(b21, b22, bResult, half);
    strassen(aResult, bResult, p7, half);

    // Assemble the quadrants of c directly from the products:
    //   c11 = p1 + p4 + p7 - p5      c12 = p3 + p5
    //   c21 = p2 + p4                c22 = p1 + p3 + p6 - p2
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            c[i][j]               = p1[i][j] + p4[i][j] + p7[i][j] - p5[i][j];
            c[i][j + half]        = p3[i][j] + p5[i][j];
            c[i + half][j]        = p2[i][j] + p4[i][j];
            c[i + half][j + half] = p1[i][j] + p3[i][j] + p6[i][j] - p2[i][j];
        }
    }

    // Release all temporaries.
    free_matrix(a11, half);
    free_matrix(a12, half);
    free_matrix(a21, half);
    free_matrix(a22, half);

    free_matrix(b11, half);
    free_matrix(b12, half);
    free_matrix(b21, half);
    free_matrix(b22, half);

    free_matrix(p1, half);
    free_matrix(p2, half);
    free_matrix(p3, half);
    free_matrix(p4, half);
    free_matrix(p5, half);
    free_matrix(p6, half);
    free_matrix(p7, half);

    free_matrix(aResult, half);
    free_matrix(bResult, half);
} // end of Strassen function

/*------------------------------------------------------------------------------*/
// Fills an n x n matrix with pseudo-random integers in [0, 99]. Values are
// drawn from rand() in row-major order, so seeding with srand() makes the
// matrix reproducible.
void gen_matrix(double** M, int n)
{
    for (int row = 0; row < n; ++row)
    {
        double *cells = M[row];
        for (int col = 0; col < n; ++col)
            cells[col] = rand() % 100;
    }
}

/*------------------------------------------------------------------------------*/
// Writes the n x n matrix M to fs, one row per line. Every value is followed
// by a single space. An extra empty line after the last row separates
// consecutive matrices.
void print_matrix(fstream& fs, double** M, int n)
{
    for (int row = 0; row < n; ++row)
    {
        const double *cells = M[row];
        for (int col = 0; col < n; ++col)
            fs << cells[col] << " ";
        fs << endl;
    }
    fs << endl;
}

/*------------------------------------------------------------------------------*/
// Writes the n x n matrices A, B and their product C to the text file `file`.
// Each matrix is preceded by a caption line.
// If `file` is null or cannot be opened for writing, an error is reported on
// stderr and nothing is written.
void mat_mult_log(double** A, double** B, double** C, int n, const char* file)
{
	if (file == NULL)
	{
		fprintf(stderr, "No output file given for the matrix log.\n");
		return;
	}

	fstream fs;
	fs.open(file, fstream::out);
	if (!fs.is_open())
	{
		// Previously a failed open was silently ignored and the log was lost.
		fprintf(stderr, "Cannot open output file `%s' for writing.\n", file);
		return;
	}

	fs<<"Random Matrix A:"<<endl;
	print_matrix(fs,A,n);
	fs<<"Random Matrix B:"<<endl;
	print_matrix(fs,B,n);
	fs<<"C=A * B"<<endl;
	print_matrix(fs,C,n);

	fs.close();
}

/*------------------------------------------------------------------------------*/

// Command-line driver.
//
// Usage: matrix_mult [-s] [-n EXP] [-o FILE]
//   -s       multiply with Strassen's algorithm (default: naive algorithm)
//   -n EXP   use matrices of dimension 2^EXP (default dimension: 2)
//   -o FILE  write A, B and C = A * B to FILE
//
// Generates two random mdim x mdim matrices, multiplies them, optionally
// logs them, and frees them. Returns 0 on success, 1 on a command-line error.
int main(int argc, char** argv)
{
	srand(time(NULL));

	// Largest accepted -n exponent: 1 << 30 still fits in an int.
	const long kMaxExponent = 30;

	int mdim=2;	// matrix dimension
	char* output=NULL;
	bool is_strassen=false;
	int c;

	while ((c = getopt (argc, argv, "sn:o:")) != -1)
	{
		switch (c)
		{
		case 's':
			is_strassen=true;
			break;
		case 'n':
		{
			// Parse the exponent strictly and compute 2^exp with an integer
			// shift. atoi() accepts trailing garbage and cannot report errors.
			// pow() goes through floating point.
			char* end = NULL;
			long exponent = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' ||
			    exponent < 0 || exponent > kMaxExponent)
			{
				fprintf(stderr,
				        "Option -n expects an integer exponent between 0 and %ld, got `%s'.\n",
				        kMaxExponent, optarg);
				return 1;
			}
			mdim = 1 << static_cast<int>(exponent); // 2^n dimensions
			break;
		}
		case 'o':
			output = optarg; // path of the log file
			break;
		case '?':
			// Both -n and -o take an argument.
			if (optopt == 'n' || optopt == 'o')
				fprintf (stderr, "Option -%c requires an argument.\n", optopt);
			else if (isprint (optopt))
				fprintf (stderr, "Unknown option `-%c'.\n", optopt);
			else
				fprintf (stderr,
				         "Unknown option character `\\x%x'.\n",
				         optopt);
			return 1;
		default:
			abort ();
		}
	}

	// create new matrices
	double** A=allocate_matrix(mdim);
	double** B=allocate_matrix(mdim);
	double** C=allocate_matrix(mdim);
	gen_matrix(A,mdim);
	gen_matrix(B,mdim);

	// matrices multiplication
	if(is_strassen)
		strassen(A,B,C,mdim);
	else
		naive(A,B,C,mdim);

	if(output!=NULL)
		mat_mult_log(A,B,C,mdim,output);

	free_matrix(A,mdim);
	free_matrix(B,mdim);
	free_matrix(C,mdim);

	return 0;
}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值