[算法]用C++实现Strassen方法求矩阵乘法(详细思路+代码+注释+测试)

前言

记录算法分析作业

学完《数据结构与算法分析(C++版)》(第三版)16.3.3节Strassen矩阵相乘的算法流程后,用C++实现Strassen方法求矩阵乘法

参考了这个博客的思路link

Strassen矩阵相乘的算法,相比起普通算法,只是少了一次乘法,时间复杂度却少很多。由此可见,一个细小的差别说不定就会导致后果差别很大呀。(跑题~~

正文

以下是实现过程

1. 实现矩阵加法功能

//矩阵加法
void Matrix_Sum(int n, int** MatrixA, int** MatrixB, int** MatrixSum) {
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++)
			MatrixSum[i][j] = MatrixA[i][j] + MatrixB[i][j];
}

2. 实现矩阵减法功能

//矩阵减法
void Matrix_Sub(int n, int** MatrixA, int** MatrixB, int** MatrixSub) {
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++)
			MatrixSub[i][j] = MatrixA[i][j] - MatrixB[i][j];
}

3. 实现矩阵乘法功能

(这个需要注意一下,相比之下复杂一点,这个也就是我们平常计算矩阵乘法的算法,后续可以用来检验Strassen方法的正确性


//矩阵乘法
void Matrix_Mul(int n, int** MatrixA, int** MatrixB, int** MatrixMul) {
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++) {
			MatrixMul[i][j] = 0;
			for (int k = 0; k < n; k++)
				MatrixMul[i][j] = MatrixMul[i][j] + MatrixA[i][k] * MatrixB[k][j];
		}
			
}

4. 实现Strassen方法

用几个二维数组来存储数据,主要是各种二维数组的赋值繁琐了一点,写的时候要注意数组名字,思路很简单。

void Strassen(int N, int** MatrixA, int** MatrixB, int** MatrixC ) {
	int n = N / 2; //分治思想
	//初始化每个小矩阵的大小
	//数组的第二维一定要显示指定
	int** MatrixA11 = new int* [n];
	int** MatrixA12 = new int* [n];
	int** MatrixA21 = new int* [n];
	int** MatrixA22 = new int* [n];
	int** MatrixB11 = new int* [n];
	int** MatrixB12 = new int* [n];
	int** MatrixB21 = new int* [n];
	int** MatrixB22 = new int* [n];
	int** MatrixC11 = new int* [n];
	int** MatrixC12 = new int* [n];
	int** MatrixC21 = new int* [n];
	int** MatrixC22 = new int* [n];

	for (int i = 0; i < n ; i++) {  //分配连续内存
		MatrixA11[i] = new int[n];
		MatrixA12[i] = new int[n];
		MatrixA21[i] = new int[n];
		MatrixA22[i] = new int[n];
		MatrixB11[i] = new int[n];
		MatrixB12[i] = new int[n];
		MatrixB21[i] = new int[n];
		MatrixB22[i] = new int[n];
		MatrixC11[i] = new int[n];
		MatrixC12[i] = new int[n];
		MatrixC21[i] = new int[n];
		MatrixC22[i] = new int[n];
	}
	//为每个小矩阵赋值,将大矩阵分割为4个小矩阵
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++) {
			MatrixA11[i][j] = MatrixA[i][j];
			MatrixA12[i][j] = MatrixA[i][j + n];
			MatrixA21[i][j] = MatrixA[i + n][j];
			MatrixA22[i][j] = MatrixA[i + n][j + n];

			MatrixB11[i][j] = MatrixB[i][j];
			MatrixB12[i][j] = MatrixB[i][j + n];
			MatrixB21[i][j] = MatrixB[i + n][j];
			MatrixB22[i][j] = MatrixB[i + n][j + n];
		}		

	//存放加减法结果
	int** S1 = new int* [n];
	int** S2 = new int* [n];
	int** S3 = new int* [n];
	int** S4 = new int* [n];
	int** S5 = new int* [n];
	int** S6 = new int* [n];
	int** S7 = new int* [n];
	int** S8 = new int* [n];
	int** S9 = new int* [n];
	int** S10 = new int* [n];

	for (int i = 0; i < n; i++) {  //分配连续内存
		S1[i] = new int[n];
		S2[i] = new int[n];
		S3[i] = new int[n];
		S4[i] = new int[n];
		S5[i] = new int[n];
		S6[i] = new int[n];
		S7[i] = new int[n];
		S8[i] = new int[n];
		S9[i] = new int[n];
		S10[i] = new int[n];
	}
	//计算
	Matrix_Sub(n, MatrixA12, MatrixA22, S1);
	Matrix_Sum(n, MatrixB21, MatrixB22, S2);
	Matrix_Sum(n, MatrixA11, MatrixA22, S3);
	Matrix_Sum(n, MatrixB11, MatrixB22, S4);
	Matrix_Sub(n, MatrixA11, MatrixA21, S5);
	Matrix_Sum(n, MatrixB11, MatrixB12, S6);
	Matrix_Sum(n, MatrixA11, MatrixA12, S7);
	Matrix_Sub(n, MatrixB12, MatrixB22, S8);
	Matrix_Sub(n, MatrixB21, MatrixB11, S9);
	Matrix_Sum(n, MatrixA21, MatrixA22, S10);

	//存放乘法结果
	int** M1 = new int* [n];
	int** M2 = new int* [n];
	int** M3 = new int* [n];
	int** M4 = new int* [n];
	int** M5 = new int* [n];
	int** M6 = new int* [n];
	int** M7 = new int* [n];

	for (int i = 0; i < n; i++) {  //分配连续内存
		M1[i] = new int[n];
		M2[i] = new int[n];
		M3[i] = new int[n];
		M4[i] = new int[n];
		M5[i] = new int[n];
		M6[i] = new int[n];
		M7[i] = new int[n];
	}
	Matrix_Mul(n, S1, S2, M1);
	Matrix_Mul(n, S3, S4, M2);
	Matrix_Mul(n, S5, S6, M3);
	Matrix_Mul(n, S7, MatrixB22, M4);
	Matrix_Mul(n, MatrixA11, S8, M5);
	Matrix_Mul(n, MatrixA22, S9, M6);
	Matrix_Mul(n, S10, MatrixB11, M7);

	//finally
	//计算C
	Matrix_Sum(n, M1, M2, MatrixC11);
	Matrix_Sub(n, MatrixC11, M4, MatrixC11);
	Matrix_Sum(n, MatrixC11, M6, MatrixC11);

	Matrix_Sum(n, M4, M5, MatrixC12);
	Matrix_Sum(n, M6, M7, MatrixC21);

	Matrix_Sub(n, M2, M3, MatrixC22);
	Matrix_Sum(n, MatrixC22, M5, MatrixC22);
	Matrix_Sub(n, MatrixC22, M7, MatrixC22);

	//将C合并
	for (int i = 0; i < n; i++)
		for (int j = 0; j < n; j++) {
			MatrixC[i][j] = MatrixC11[i][j];
			MatrixC[i][j + n] = MatrixC12[i][j];
			MatrixC[i + n][j] = MatrixC21[i][j];
			MatrixC[i + n][j + n] = MatrixC22[i][j];
		}
}

5.初始化一个矩阵

//初始化
void Init_Matrix(int N, int** MatrixA) {
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			MatrixA[i][j] = rand() % 10 + 1;//产生1~10
		}
	}
}

6. 打印矩阵

//打印矩阵
void print(int** MatrixA, int n) {
	for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++)
			cout << MatrixA[i][j] << " ";
	cout << endl;
	}
	cout << endl;
}

7. 主函数,开始测试!

#include<iostream>
#include"Strassen.h"
using namespace std;

//初始化
void Init_Matrix(int N, int** MatrixA) {
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			MatrixA[i][j] = rand() % 10 + 1;//产生1~10
		}
	}
}

//打印矩阵
void print(int** MatrixA, int n) {
	for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++)
			cout << MatrixA[i][j] << " ";
	cout << endl;
	}
	cout << endl;
}

	int main() {
	//time
	clock_t startTime_For_Normal_Multipilication;
	clock_t endTime_For_Normal_Multipilication;

	clock_t startTime_For_Strassen;
	clock_t endTime_For_Strassen;

	time_t start, end;

	//准备工作
	int MatrixSize; //矩阵大小	
	cout << "请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
	cin >> MatrixSize;
	int N = MatrixSize;

	int** MatrixA = new int* [N];
	int** MatrixB = new int* [N];
	int** MatrixC = new int* [N];
	int** MatrixT = new int* [N]; //用于检测结果是否正确
	for (int i = 0; i < N; i++) {
		MatrixA[i] = new int[N];
		MatrixB[i] = new int[N];
		MatrixC[i] = new int[N];
		MatrixT[i] = new int[N];
	}
	Init_Matrix(N, MatrixA);
	Init_Matrix(N, MatrixB);

	//计算
	cout << "A矩阵为:" << endl;
	print(MatrixA, N);
	
	cout << "B矩阵为:" << endl;
	print(MatrixB, N);

	cout << "************用常用方法将矩阵相乘************" << endl;
	cout << "起始时间为:  " << (startTime_For_Normal_Multipilication = clock()) << endl;
	Matrix_Mul(N, MatrixA, MatrixB, MatrixT);
	cout << "结束时间为: " << (endTime_For_Normal_Multipilication = clock()) << endl;
	cout << "打印矩阵:" << endl;
	print(MatrixT, N);
	
	cout << "************用Strassen方法将矩阵相乘************" << endl;
	cout << "起始时间为: " << (startTime_For_Strassen = clock()) << endl;
	Strassen(N, MatrixA, MatrixB, MatrixC);
	cout << "结束时间为: " << (endTime_For_Strassen = clock()) << endl;
	cout << "打印矩阵:" << endl;
	print(MatrixC, N);

	//比较所用时间
	cout << "常用方法耗时:  " << (endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication)
		<< " Clocks.." << (endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication) / CLOCKS_PER_SEC << " Sec" << endl;
	cout << "Strassen方法耗时: " << (endTime_For_Strassen - startTime_For_Strassen) 
		<< " Clocks.." << (endTime_For_Strassen - startTime_For_Strassen) / CLOCKS_PER_SEC << " Sec\n";


	return 0;
}

运行结果

将数组定为1024这个大小,可以很明显地看出来差距

在这里插入图片描述

总结

1.思路简单,注意数字和字母就行。
2.主要是更加深刻地体会到了如何传参二维数组,感觉这个才是主要学到的(关于这一点,可以参考这个链接,写的很详细 使用参数传递二维数组

  • 8
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值