前言
记录算法分析作业
学完《数据结构与算法分析(C++版)》(第三版)16.3.3节Strassen矩阵相乘的算法流程后,用C++实现Strassen方法求矩阵乘法
参考了这个博客的思路link
Strassen矩阵相乘的算法,相比起普通算法,只是少了一次乘法,时间复杂度却少很多。由此可见,一个细小的差别说不定就会导致后果差别很大呀。(跑题~~
正文
以下是实现过程
1. 实现矩阵加法功能
//矩阵加法
void Matrix_Sum(int n, int** MatrixA, int** MatrixB, int** MatrixSum) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
MatrixSum[i][j] = MatrixA[i][j] + MatrixB[i][j];
}
2. 实现矩阵减法功能
//矩阵减法
void Matrix_Sub(int n, int** MatrixA, int** MatrixB, int** MatrixSub) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
MatrixSub[i][j] = MatrixA[i][j] - MatrixB[i][j];
}
3. 实现矩阵乘法功能
(这个需要注意一下,相比之下复杂一点,这个也就是我们平常计算矩阵乘法的算法,后续可以用来检验Strassen方法的正确性
//矩阵乘法
void Matrix_Mul(int n, int** MatrixA, int** MatrixB, int** MatrixMul) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++) {
MatrixMul[i][j] = 0;
for (int k = 0; k < n; k++)
MatrixMul[i][j] = MatrixMul[i][j] + MatrixA[i][k] * MatrixB[k][j];
}
}
4. 实现Strassen方法
用几个二维数组来存储数据,主要是各种二维数组的赋值繁琐了一点,写的时候要注意数组名字,思路很简单。
void Strassen(int N, int** MatrixA, int** MatrixB, int** MatrixC ) {
int n = N / 2; //分治思想
//初始化每个小矩阵的大小
//数组的第二维一定要显示指定
int** MatrixA11 = new int* [n];
int** MatrixA12 = new int* [n];
int** MatrixA21 = new int* [n];
int** MatrixA22 = new int* [n];
int** MatrixB11 = new int* [n];
int** MatrixB12 = new int* [n];
int** MatrixB21 = new int* [n];
int** MatrixB22 = new int* [n];
int** MatrixC11 = new int* [n];
int** MatrixC12 = new int* [n];
int** MatrixC21 = new int* [n];
int** MatrixC22 = new int* [n];
for (int i = 0; i < n ; i++) { //分配连续内存
MatrixA11[i] = new int[n];
MatrixA12[i] = new int[n];
MatrixA21[i] = new int[n];
MatrixA22[i] = new int[n];
MatrixB11[i] = new int[n];
MatrixB12[i] = new int[n];
MatrixB21[i] = new int[n];
MatrixB22[i] = new int[n];
MatrixC11[i] = new int[n];
MatrixC12[i] = new int[n];
MatrixC21[i] = new int[n];
MatrixC22[i] = new int[n];
}
//为每个小矩阵赋值,将大矩阵分割为4个小矩阵
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++) {
MatrixA11[i][j] = MatrixA[i][j];
MatrixA12[i][j] = MatrixA[i][j + n];
MatrixA21[i][j] = MatrixA[i + n][j];
MatrixA22[i][j] = MatrixA[i + n][j + n];
MatrixB11[i][j] = MatrixB[i][j];
MatrixB12[i][j] = MatrixB[i][j + n];
MatrixB21[i][j] = MatrixB[i + n][j];
MatrixB22[i][j] = MatrixB[i + n][j + n];
}
//存放加减法结果
int** S1 = new int* [n];
int** S2 = new int* [n];
int** S3 = new int* [n];
int** S4 = new int* [n];
int** S5 = new int* [n];
int** S6 = new int* [n];
int** S7 = new int* [n];
int** S8 = new int* [n];
int** S9 = new int* [n];
int** S10 = new int* [n];
for (int i = 0; i < n; i++) { //分配连续内存
S1[i] = new int[n];
S2[i] = new int[n];
S3[i] = new int[n];
S4[i] = new int[n];
S5[i] = new int[n];
S6[i] = new int[n];
S7[i] = new int[n];
S8[i] = new int[n];
S9[i] = new int[n];
S10[i] = new int[n];
}
//计算
Matrix_Sub(n, MatrixA12, MatrixA22, S1);
Matrix_Sum(n, MatrixB21, MatrixB22, S2);
Matrix_Sum(n, MatrixA11, MatrixA22, S3);
Matrix_Sum(n, MatrixB11, MatrixB22, S4);
Matrix_Sub(n, MatrixA11, MatrixA21, S5);
Matrix_Sum(n, MatrixB11, MatrixB12, S6);
Matrix_Sum(n, MatrixA11, MatrixA12, S7);
Matrix_Sub(n, MatrixB12, MatrixB22, S8);
Matrix_Sub(n, MatrixB21, MatrixB11, S9);
Matrix_Sum(n, MatrixA21, MatrixA22, S10);
//存放乘法结果
int** M1 = new int* [n];
int** M2 = new int* [n];
int** M3 = new int* [n];
int** M4 = new int* [n];
int** M5 = new int* [n];
int** M6 = new int* [n];
int** M7 = new int* [n];
for (int i = 0; i < n; i++) { //分配连续内存
M1[i] = new int[n];
M2[i] = new int[n];
M3[i] = new int[n];
M4[i] = new int[n];
M5[i] = new int[n];
M6[i] = new int[n];
M7[i] = new int[n];
}
Matrix_Mul(n, S1, S2, M1);
Matrix_Mul(n, S3, S4, M2);
Matrix_Mul(n, S5, S6, M3);
Matrix_Mul(n, S7, MatrixB22, M4);
Matrix_Mul(n, MatrixA11, S8, M5);
Matrix_Mul(n, MatrixA22, S9, M6);
Matrix_Mul(n, S10, MatrixB11, M7);
//finally
//计算C
Matrix_Sum(n, M1, M2, MatrixC11);
Matrix_Sub(n, MatrixC11, M4, MatrixC11);
Matrix_Sum(n, MatrixC11, M6, MatrixC11);
Matrix_Sum(n, M4, M5, MatrixC12);
Matrix_Sum(n, M6, M7, MatrixC21);
Matrix_Sub(n, M2, M3, MatrixC22);
Matrix_Sum(n, MatrixC22, M5, MatrixC22);
Matrix_Sub(n, MatrixC22, M7, MatrixC22);
//将C合并
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++) {
MatrixC[i][j] = MatrixC11[i][j];
MatrixC[i][j + n] = MatrixC12[i][j];
MatrixC[i + n][j] = MatrixC21[i][j];
MatrixC[i + n][j + n] = MatrixC22[i][j];
}
}
5.初始化一个矩阵
//初始化
void Init_Matrix(int N, int** MatrixA) {
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
MatrixA[i][j] = rand() % 10 + 1;//产生1~10
}
}
}
6. 打印矩阵
//打印矩阵
void print(int** MatrixA, int n) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++)
cout << MatrixA[i][j] << " ";
cout << endl;
}
cout << endl;
}
7. 主函数,开始测试!
#include<iostream>
#include"Strassen.h"
using namespace std;
//初始化
void Init_Matrix(int N, int** MatrixA) {
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
MatrixA[i][j] = rand() % 10 + 1;//产生1~10
}
}
}
//打印矩阵
void print(int** MatrixA, int n) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++)
cout << MatrixA[i][j] << " ";
cout << endl;
}
cout << endl;
}
int main() {
//time
clock_t startTime_For_Normal_Multipilication;
clock_t endTime_For_Normal_Multipilication;
clock_t startTime_For_Strassen;
clock_t endTime_For_Strassen;
time_t start, end;
//准备工作
int MatrixSize; //矩阵大小
cout << "请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
cin >> MatrixSize;
int N = MatrixSize;
int** MatrixA = new int* [N];
int** MatrixB = new int* [N];
int** MatrixC = new int* [N];
int** MatrixT = new int* [N]; //用于检测结果是否正确
for (int i = 0; i < N; i++) {
MatrixA[i] = new int[N];
MatrixB[i] = new int[N];
MatrixC[i] = new int[N];
MatrixT[i] = new int[N];
}
Init_Matrix(N, MatrixA);
Init_Matrix(N, MatrixB);
//计算
cout << "A矩阵为:" << endl;
print(MatrixA, N);
cout << "B矩阵为:" << endl;
print(MatrixB, N);
cout << "************用常用方法将矩阵相乘************" << endl;
cout << "起始时间为: " << (startTime_For_Normal_Multipilication = clock()) << endl;
Matrix_Mul(N, MatrixA, MatrixB, MatrixT);
cout << "结束时间为: " << (endTime_For_Normal_Multipilication = clock()) << endl;
cout << "打印矩阵:" << endl;
print(MatrixT, N);
cout << "************用Strassen方法将矩阵相乘************" << endl;
cout << "起始时间为: " << (startTime_For_Strassen = clock()) << endl;
Strassen(N, MatrixA, MatrixB, MatrixC);
cout << "结束时间为: " << (endTime_For_Strassen = clock()) << endl;
cout << "打印矩阵:" << endl;
print(MatrixC, N);
//比较所用时间
cout << "常用方法耗时: " << (endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication)
<< " Clocks.." << (endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication) / CLOCKS_PER_SEC << " Sec" << endl;
cout << "Strassen方法耗时: " << (endTime_For_Strassen - startTime_For_Strassen)
<< " Clocks.." << (endTime_For_Strassen - startTime_For_Strassen) / CLOCKS_PER_SEC << " Sec\n";
return 0;
}
运行结果
将数组定为1024这个大小,可以很明显地看出来差距
总结
1.思路简单,注意数字和字母就行。
2.主要是更加深刻地体会到了如何传参二维数组,感觉这个才是主要学到的(关于这一点,可以参考这个链接,写的很详细 使用参数传递二维数组