[C++]矩阵乘法Strassen算法-----代码实现

关于矩阵乘法的Strassen算法,这里不再叙述,推荐简书博客:
https://www.jianshu.com/p/6e21f8e872fd
本篇博客便是参考该简书的算法思想,采用C++进行代码实现,可作为上述简书的补充。

关于使用参数传递二维数组,可参考:
https://www.cnblogs.com/huipengly/p/8892110.html

#include<iostream>
using namespace std;
/*
可以传递二维数组作为参数,有两种方法,
1.change(int **a)直接传递一个指针进去
2.change(int a[][10])数组的第二维维度一定要显式指定
*/
//打印矩阵
void PrintMatrix(int **MatrixA, int N){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			cout << MatrixA[i][j] << " ";
		}
		cout << endl;
	}
}
//矩阵加法
void Matrix_Sum(int N, int** MatrixA, int** MatrixB, int** Sum_Matrix){

	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			Sum_Matrix[i][j] = MatrixA[i][j] + MatrixB[i][j];
		}
	}
}
//矩阵减法
void Matrix_Sub(int N, int** MatrixA, int** MatrixB, int** Sub_Matrix){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			Sub_Matrix[i][j] = MatrixA[i][j] -MatrixB[i][j];
		}
	}
}
//矩阵乘法
void Matrix_Mul(int N, int** MatrixA, int** MatrixB, int** Mul_Matrix){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			Mul_Matrix[i][j] = 0;
			for (int k = 0; k < N; k++){
				Mul_Matrix[i][j] = Mul_Matrix[i][j] + MatrixA[i][k] * MatrixB[k][j];
			}
		}
	}
}
//基于Strassen的矩阵乘法,输入矩阵A,B,C
void Strassen(int N, int** MatrixA, int** MatrixB, int** MatrixC){
	//新建矩阵
	int** MatrixA11;
	int** MatrixA12;
	int** MatrixA21;
	int** MatrixA22;

	int** MatrixB11;
	int** MatrixB12;
	int** MatrixB21;
	int** MatrixB22;

	int** MatrixC11;
	int** MatrixC12;
	int** MatrixC21;
	int** MatrixC22;
	//初始化每个小矩阵的大小
	MatrixA11 = new int*[N/2];//数组的第二维一定要显示指定
	MatrixA12 = new int*[N/2];
	MatrixA21 = new int*[N/2];
	MatrixA22 = new int*[N/2];

	MatrixB11 = new int*[N/2];
	MatrixB12 = new int*[N/2];
	MatrixB21 = new int*[N/2];
	MatrixB22 = new int*[N/2];

	MatrixC11 = new int*[N/2];
	MatrixC12 = new int*[N/2];
	MatrixC21 = new int*[N/2];
	MatrixC22 = new int*[N/2];
	for (int i = 0; i < N/2; i++)//分配连续内存
	{
		MatrixA11[i] = new int[N/2];
		MatrixA12[i] = new int[N / 2];
		MatrixA21[i] = new int[N / 2];
		MatrixA22[i] = new int[N / 2];

		MatrixB11[i] = new int[N / 2];
		MatrixB12[i] = new int[N / 2];
		MatrixB21[i] = new int[N / 2];
		MatrixB22[i] = new int[N / 2];
		
		MatrixC11[i] = new int[N / 2];
		MatrixC12[i] = new int[N / 2];
		MatrixC21[i] = new int[N / 2];
		MatrixC22[i] = new int[N / 2];
	}
	 //为每个小矩阵赋值,将大矩阵分割为4个小矩阵
	for (int i = 0; i < N / 2; i++){
		for (int j = 0; j < N / 2; j++){
			MatrixA11[i][j] = MatrixA[i][j];
			MatrixA12[i][j] = MatrixA[i][j + N / 2];
			MatrixA21[i][j] = MatrixA[i + N / 2][j];
			MatrixA22[i][j] = MatrixA[i + N / 2][j + N / 2];

			MatrixB11[i][j] = MatrixB[i][j];
			MatrixB12[i][j] = MatrixB[i][j + N / 2];
			MatrixB21[i][j] = MatrixB[i + N / 2][j];
			MatrixB22[i][j] = MatrixB[i + N / 2][j + N / 2];
		}
	}
   //做10个辅助矩阵S,计算加法
	int** MatrixS1=new int*[N/2];
	int** MatrixS2 = new int*[N/2];
	int** MatrixS3 = new int*[N/2];
	int** MatrixS4 = new int*[N / 2];
	int** MatrixS5 = new int*[N / 2];
	int** MatrixS6 = new int*[N / 2];
	int** MatrixS7 = new int*[N / 2];
	int** MatrixS8 = new int*[N / 2];
	int** MatrixS9 = new int*[N / 2];
	int** MatrixS10 = new int*[N / 2];
	
	for (int i = 0; i < N / 2; i++)//分配连续内存
	{
		MatrixS1[i] = new int[N / 2];
		MatrixS2[i] = new int[N / 2];
		MatrixS3[i] = new int[N / 2];
		MatrixS4[i] = new int[N / 2];
		MatrixS5[i] = new int[N / 2];
		MatrixS6[i] = new int[N / 2];
		MatrixS7[i] = new int[N / 2];
		MatrixS8[i] = new int[N / 2];
		MatrixS9[i] = new int[N / 2];
		MatrixS10[i] = new int[N / 2];
	}

	Matrix_Sub(N/2, MatrixB12, MatrixB22, MatrixS1);//S1 = B12 - B22
	Matrix_Sum(N / 2, MatrixA11, MatrixA12, MatrixS2);//S2 = A11 + A12
	Matrix_Sum(N / 2, MatrixA21, MatrixA22, MatrixS3);//S3 = A21 + A22
	Matrix_Sub(N / 2, MatrixB21, MatrixB11, MatrixS4);//S4 = B21 - B11
	Matrix_Sum(N / 2, MatrixA11, MatrixA22, MatrixS5);//S5 = A11 + A22
	Matrix_Sum(N / 2, MatrixB11, MatrixB22, MatrixS6);//S6 = B11 + B22
	Matrix_Sub(N / 2, MatrixA12, MatrixA22, MatrixS7);//S7 = A12 - A22
	Matrix_Sum(N / 2, MatrixB21, MatrixB22, MatrixS8);//S8 = B21 + B22
	Matrix_Sub(N / 2, MatrixA11, MatrixA21, MatrixS9);//S9 = A11 - A21
	Matrix_Sum(N / 2, MatrixB11, MatrixB12, MatrixS10);//S10 = B11 + B12

	//做7个辅助矩阵P,计算乘法
	int** MatrixP1 = new int*[N / 2];
	int** MatrixP2 = new int*[N / 2];
	int** MatrixP3 = new int*[N / 2];
	int** MatrixP4 = new int*[N / 2];
	int** MatrixP5 = new int*[N / 2];
	int** MatrixP6 = new int*[N / 2];
	int** MatrixP7 = new int*[N / 2];

	for (int i = 0; i < N / 2; i++)//分配连续内存
	{
		MatrixP1[i] = new int[N / 2];
		MatrixP2[i] = new int[N / 2];
		MatrixP3[i] = new int[N / 2];
		MatrixP4[i] = new int[N / 2];
		MatrixP5[i] = new int[N / 2];
		MatrixP6[i] = new int[N / 2];
		MatrixP7[i] = new int[N / 2];
	}
	Matrix_Mul(N / 2, MatrixA11, MatrixS1, MatrixP1);//P1 = A11 • S1
	Matrix_Mul(N / 2, MatrixS2, MatrixB22, MatrixP2);//P2 = S2 • B22
	Matrix_Mul(N / 2, MatrixS3, MatrixB11, MatrixP3);//P3 = S3 • B11
	Matrix_Mul(N / 2, MatrixA22, MatrixS4, MatrixP4);//P4 = A22 • S4
	Matrix_Mul(N / 2, MatrixS5, MatrixS6, MatrixP5);//P5 = S5 • S6
	Matrix_Mul(N / 2, MatrixS7, MatrixS8, MatrixP6);//P6 = S7 • S8
	Matrix_Mul(N / 2, MatrixS9, MatrixS10, MatrixP7);//P7 = S9 • S10

	//根据以上7个结果计算C矩阵
	Matrix_Sum(N / 2, MatrixP5, MatrixP4, MatrixC11); //C11 = P5 + P4 - P2 + P6
	Matrix_Sub(N / 2, MatrixC11, MatrixP2, MatrixC11);
	Matrix_Sum(N / 2, MatrixC11, MatrixP6, MatrixC11);
	Matrix_Sum(N / 2, MatrixP1, MatrixP2, MatrixC12);//C12 = P1 + P2
	Matrix_Sum(N / 2, MatrixP3, MatrixP4, MatrixC21);	//C21 = P3 + P4
	Matrix_Sum(N / 2, MatrixP5, MatrixP1, MatrixC22);	//C22 = P5 + P1 - P3 - P7
	Matrix_Sub(N / 2, MatrixC22, MatrixP3, MatrixC22);
	Matrix_Sub(N / 2, MatrixC22, MatrixP7, MatrixC22);
	//将C11,C12,C21,C21合并为C矩阵
	for (int i = 0; i < N / 2; i++){
		for (int j = 0; j < N / 2; j++){
			MatrixC[i][j] = MatrixC11[i][j];
			MatrixC[i][j+N/2] = MatrixC12[i][j];
			MatrixC[i+N/2][j] = MatrixC21[i][j];
			MatrixC[i+N/2][j+N/2] = MatrixC22[i][j];
		}
	}
}
//朴素矩阵相乘算法
void NormalMul_Matrix(int N, int **MatrixA, int **MatrixB, int **MatrixC){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			MatrixC[i][j] = 0;
			for (int k = 0; k < N; k++){
				MatrixC[i][j] = MatrixC[i][j]+MatrixA[i][k] * MatrixB[k][j];
			}
		}
	}
}

	//初始化矩阵A,B
void Init_Matrix(int N,int** MatrixA, int** MatrixB){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			MatrixA[i][j] = rand() % 10 + 1;//产生1~10
			MatrixB[i][j] = rand() % 10 + 1;
		}
	}
	/*
	C++中rand() 函数的用法
	1、rand()不需要参数,它会返回一个从0到最大随机数的任意整数,最大随机数的大小通常是固定的一个大整数。
	2、如果你要产生0~99这100个整数中的一个随机整数,可以表达为:int num = rand() % 100;
	这样,num的值就是一个0~99中的一个随机数了。
	3、如果要产生1~100,则是这样:int num = rand() % 100 + 1;
	4、总结来说,可以表示为:int num = rand() % n +a;
	其中的a是起始值,n-1+a是终止值,n是整数的范围。
	*/
}
//矩阵A,B测试用例
void Test_Matrix(int N, int** MatrixA, int** MatrixB){
	for (int i = 0; i < N; i++){
		for (int j = 0; j < N; j++){
			MatrixA[i][j] = 1;
			MatrixB[i][j] = 2;
		}
	}
	}

void main(){
	int N;
	cout << "请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
	cin >> N;	
	int** MatrixA = new int *[N];
	int**  MatrixB = new int *[N];
	int** MatrixC = new int *[N];//测试Strassen矩阵

	int** MatrixC1 = new int*[N];//测试朴素相乘矩阵
	for (int i = 0; i < N; i++)//分配连续内存
	{
		MatrixA[i] = new int[N];
		MatrixB[i] = new int[N];
		MatrixC[i] = new int[N];

		MatrixC1[i] = new int[N];
	}

	Init_Matrix(N, MatrixA, MatrixB);
	//Test_Matrix(N, MatrixA, MatrixB);
	cout << "A矩阵为:" << endl;
	PrintMatrix(MatrixA, N);
	cout << "B矩阵为:" << endl;
	PrintMatrix(MatrixB, N);	
	cout << "朴素矩矩阵为:" << endl;
	NormalMul_Matrix(N, MatrixA, MatrixB, MatrixC1);
	PrintMatrix(MatrixC1, N);
	cout << "Strassen矩阵为:" << endl;
	Strassen(N, MatrixA, MatrixB, MatrixC);
	PrintMatrix(MatrixC, N);
	system("pause");				//等待按任意键退出
	
}

运行结果如下:
在这里插入图片描述
关于Strassen算法的性能分析,可参考博客:
https://www.cnblogs.com/zhoutaotao/p/3963048.html
参考博客:
https://blog.csdn.net/zhuangxiaobin/article/details/36476769

  • 4
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值