C++ 实现 Strassen 算法:与朴素算法的对比及优化
- 编程实现普通的矩阵乘法;
- 编程实现 Strassen’s algorithm;
- 在不同数据规模下,比较两种算法各自的运行时间;
- 修改 Strassen’s algorithm,使之适应矩阵规模 N 不是 2 的幂的情况;
- 将改进后的算法与第 2 项中的 Strassen 算法在相同数据规模下进行比较。
设计实验
将给出的实验数据写入 datas.txt 文件。为控制变量,两种算法都从该文件中读取相同的数据。记录从开始读入数据到计算完成所用的时间,取 5 次实验的平均值,并据此绘制折线图。
一、普通矩阵乘法的代码分析
#include <stdlib.h>
#include <time.h>

#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <vector>
using namespace std;
int main()
{
    // Naive O(M^3) matrix-multiplication benchmark.
    //
    // Reads N (number of test cases) and M (matrix order) from standard
    // input. For each case it reads two M x M integer matrices, row by row,
    // from "datas.txt", prints their product (one row per line, entries
    // separated by spaces), and finally prints the elapsed time in seconds
    // from the first file read to the last printed product.
    //
    // Returns 0 on success, 1 if N/M are invalid, the data file cannot be
    // opened, or the file runs out of integers.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N = 0;
    int M = 0;
    if (!(std::cin >> N >> M) || N < 0 || M <= 0)
    {
        std::cerr << "invalid input: expected N >= 0 and M > 0\n";
        return 1;
    }

    // Flat row-major buffers: element (i, j) lives at [i * m + j].
    // A variable-length array such as `int a[M][M]` is not standard C++ and
    // lives on the stack, which overflows for large M.
    const std::size_t m = static_cast<std::size_t>(M);
    std::vector<int> a(m * m);
    std::vector<int> b(m * m);

    std::ifstream fin("datas.txt");
    if (!fin)
    {
        std::cerr << "cannot open datas.txt\n";
        return 1;
    }

    // Reads one M x M matrix in row-major order; returns false if the file
    // ends early or contains a token that is not an integer.
    auto read_matrix = [&fin](std::vector<int> &mat) {
        for (int &x : mat)
            fin >> x;
        return static_cast<bool>(fin);
    };

    // steady_clock is monotonic, so the measured interval cannot be skewed by
    // system clock adjustments (high_resolution_clock may alias system_clock).
    const auto start = std::chrono::steady_clock::now();
    while (N--)
    {
        if (!read_matrix(a) || !read_matrix(b))
        {
            std::cerr << "datas.txt ended before all matrices were read\n";
            return 1;
        }
        for (std::size_t i = 0; i < m; ++i)
        {
            for (std::size_t j = 0; j < m; ++j)
            {
                // Accumulated in int, like the Strassen program's helpers, so
                // both programs print comparable results. Dot products larger
                // than INT_MAX would overflow (undefined behavior).
                int cij = 0;
                for (std::size_t k = 0; k < m; ++k)
                    cij += a[i * m + k] * b[k * m + j];
                std::cout << cij << (j + 1 == m ? '\n' : ' ');
            }
        }
    }
    const auto end = std::chrono::steady_clock::now();
    const std::chrono::duration<double> diff = end - start;
    std::cout << std::fixed << std::setprecision(10) << diff.count() << std::endl;
    return 0;
}
二、Strassen 算法的代码分析
#include <iostream>
#include<time.h>
#include<stdlib.h>
#include<fstream>
#include<iomanip>
using namespace std;
// Element-wise difference of two l x l matrices: ans = m - n.
// Each output element depends only on the inputs at the same position.
void minusm(int l, int **m, int **n, int **ans)
{
    for (int row = 0; row < l; ++row)
    {
        const int *lhs = m[row];
        const int *rhs = n[row];
        int *out = ans[row];
        for (int col = 0; col < l; ++col)
            out[col] = lhs[col] - rhs[col];
    }
}
// Element-wise sum of two l x l matrices: ans = m + n.
// Each output element depends only on the inputs at the same position.
void addm(int l, int **m, int **n, int **ans)
{
    for (int r = 0; r < l; ++r)
    {
        int *dst = ans[r];
        const int *x = m[r];
        const int *y = n[r];
        for (int c = 0; c < l; ++c)
            dst[c] = x[c] + y[c];
    }
}
// Naive O(l^3) product of two l x l matrices: ans = m * n.
// Serves as the base case of Strassen(). ans must not alias m or n: each
// ans[r][c] is zeroed and accumulated in place while m and n are still read.
void multim(int l, int **m, int **n, int **ans)
{
    for (int r = 0; r < l; ++r)
    {
        int *out = ans[r];
        for (int c = 0; c < l; ++c)
        {
            out[c] = 0;
            for (int t = 0; t < l; ++t)
                out[c] += m[r][t] * n[t][c];
        }
    }
}
void Strassen(int N, int **A, int **B, int **C) //strassen算法
{
if(N<=4)
{
multim(N, A, B, C);
}else
{
int **A11 = new int *[N / 2];
int **A12 = new int *[N / 2];
int **A21 = new int *[N / 2];
int **A22 = new int *[N / 2];
int **B11 = new int *[N / 2];
int **B12 = new int *[N / 2];
int **B21 = new int *[N / 2];
int **B22 = new int *[N / 2];
int **C11 = new int *[N / 2];
int **C12 = new int *[N / 2];
int **C21 = new int *[N / 2];
int **C22 = new int *[N / 2];
int **P1 = new int *[N / 2];
int **P2 = new int *[N / 2];
int **P3 = new int *[N / 2];
int **P4 = new int *[N / 2];
int **P5 = new int *[N / 2];
int **P6 = new int *[N / 2];
int **P7 = new int *[N / 2];
int **AR = new int *[N / 2];
int **BR = new int *[N / 2];
for (int i = 0; i < N / 2; i++)
{
A11[i] = new int[N / 2];
A12[i] = new int[N / 2];
A21[i] = new int[N / 2];
A22[i] = new int[N / 2];
B11[i] = new int[N / 2];
B12[i] = new int[N / 2];
B21[i] = new int[N / 2];
B22[i] = new int[N / 2];
C11[i] = new int[N / 2];
C12[i] = new int[N / 2];
C21[i] = new int[N / 2];
C22[i] = new int[N / 2];
P1[i] = new int[N / 2];
P2[i] = new int[N / 2];
P3[i] = new int[N / 2];
P4[i] = new int[N / 2];
P5[i] = new int[N / 2];
P6[i] = new int[N / 2];
P7[i] = new int[N / 2];
AR[i] = new int[N / 2];
BR[i] = new int[N / 2];
}
for (int i = 0; i < N / 2