Strassen矩阵乘法是一个时间复杂度小于 $O(N^3)$ 的矩阵乘法算法。它通过一系列优化技巧,把朴素分治矩阵乘法中对子矩阵乘积的八次递归计算减少到七次,从而得到低于 $O(N^3)$ 的时间复杂度(即 $O(N^{\log_2 7}) \approx O(N^{2.81})$)。该算法的详细讲解参见《算法导论》(第三版)中译本 p43-p47,这里只给出 C++ 实现而不重复阐述。
C++代码:
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>
using namespace std;
/**
 * Adds two equally sized rectangular blocks element-wise and stores the sum
 * in the top-left corner of `result`.
 *
 * All coordinates are inclusive; an `x` value selects a column and a `y`
 * value selects a row. A range whose "down" coordinate is smaller than its
 * "up" coordinate is treated as empty and writes nothing.
 *
 * @param left         matrix containing the first operand block
 * @param left_up_x    first column of the block inside `left`
 * @param left_up_y    first row of the block inside `left`
 * @param left_down_x  last column of the block inside `left`
 * @param left_down_y  last row of the block inside `left`
 * @param right        matrix containing the second operand block
 * @param right_up_x   first column of the block inside `right`
 * @param right_up_y   first row of the block inside `right`
 * @param result       destination; must already be at least as large as the
 *                     block. Cell [i][j] receives the sum of the cells at
 *                     offset (i, j) inside both blocks. It may be the same
 *                     object as an operand whose block starts at (0, 0).
 */
void matrixAdd(const vector<vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, const vector<vector<int>>& right, size_t right_up_x, size_t right_up_y, vector<vector<int>>& result)
{
    // Compare `offset + up <= down` instead of `offset <= down - up`: the
    // unsigned subtraction wraps around to a huge value for an empty range.
    for (size_t i = 0; i + left_up_y <= left_down_y; ++i)
    {
        for (size_t j = 0; j + left_up_x <= left_down_x; ++j)
        {
            result[i][j] = left[i + left_up_y][j + left_up_x] + right[i + right_up_y][j + right_up_x];
        }
    }
}
/**
 * Subtracts one rectangular block from another of the same size and stores
 * the difference (left block minus right block) in the top-left corner of
 * `result`.
 *
 * All coordinates are inclusive; an `x` value selects a column and a `y`
 * value selects a row. A range whose "down" coordinate is smaller than its
 * "up" coordinate is treated as empty and writes nothing.
 *
 * @param left         matrix containing the minuend block
 * @param left_up_x    first column of the block inside `left`
 * @param left_up_y    first row of the block inside `left`
 * @param left_down_x  last column of the block inside `left`
 * @param left_down_y  last row of the block inside `left`
 * @param right        matrix containing the subtrahend block
 * @param right_up_x   first column of the block inside `right`
 * @param right_up_y   first row of the block inside `right`
 * @param result       destination; must already be at least as large as the
 *                     block. It may be the same object as an operand whose
 *                     block starts at (0, 0).
 */
void matrixSub(const vector<vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, const vector<vector<int>>& right, size_t right_up_x, size_t right_up_y, vector<vector<int>>& result)
{
    // Compare `offset + up <= down` instead of `offset <= down - up`: the
    // unsigned subtraction wraps around to a huge value for an empty range.
    for (size_t i = 0; i + left_up_y <= left_down_y; ++i)
    {
        for (size_t j = 0; j + left_up_x <= left_down_x; ++j)
        {
            result[i][j] = left[i + left_up_y][j + left_up_x] - right[i + right_up_y][j + right_up_x];
        }
    }
}
/**
 * Copies a rectangular block of `left` into the top-left corner of `result`.
 *
 * All coordinates are inclusive; an `x` value selects a column and a `y`
 * value selects a row. A range whose "down" coordinate is smaller than its
 * "up" coordinate is treated as empty and copies nothing.
 *
 * @param left         source matrix
 * @param left_up_x    first column of the block
 * @param left_up_y    first row of the block
 * @param left_down_x  last column of the block
 * @param left_down_y  last row of the block
 * @param result       destination; must already be at least as large as the
 *                     block. Cells outside the block are left untouched.
 */
void extract(const vector<vector<int>>& left, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, vector<vector<int>>& result)
{
    // Compare `offset + up <= down` instead of `offset <= down - up`: the
    // unsigned subtraction wraps around to a huge value for an empty range.
    for (size_t i = 0; i + left_up_y <= left_down_y; ++i)
    {
        for (size_t j = 0; j + left_up_x <= left_down_x; ++j)
        {
            result[i][j] = left[i + left_up_y][j + left_up_x];
        }
    }
}
/**
 * Pastes `left` into a rectangular window of `be_added`.
 *
 * Despite the name, the window is overwritten, not accumulated into. The
 * window spans columns left_up_x..left_down_x and rows left_up_y..left_down_y
 * (inclusive); the cell at window offset (r, c) receives left[r][c]. Only the
 * part of `left` that fits into the window is read, and a window whose "down"
 * coordinate is smaller than its "up" coordinate is empty.
 *
 * @param be_added     destination matrix
 * @param left_up_x    first destination column
 * @param left_up_y    first destination row
 * @param left_down_x  last destination column
 * @param left_down_y  last destination row
 * @param left         source matrix, read starting at [0][0]
 */
void add(vector<vector<int>>& be_added, size_t left_up_x, size_t left_up_y, size_t left_down_x, size_t left_down_y, vector<vector<int>>& left)
{
    // Walk the destination window with absolute indices and translate back
    // to the source by subtracting the window origin.
    for (size_t row = left_up_y; row <= left_down_y; ++row)
    {
        for (size_t col = left_up_x; col <= left_down_x; ++col)
        {
            be_added[row][col] = left[row - left_up_y][col - left_up_x];
        }
    }
}
/**
 * Multiplies A (s x n) by B (n x m) with Strassen's seven-product recursion
 * and writes the s x m product into P.
 *
 * At every recursion level each odd dimension is zero-padded to the next
 * even size, the operands are split into four quadrants, ten quadrant
 * sums/differences S1..S10 and seven recursive products P1..P7 are formed,
 * and the four result quadrants are assembled from them. Padding is done on
 * local copies, so A and B are never modified.
 *
 * @param A  left operand; every row must hold B.size() entries
 * @param B  right operand; must be rectangular
 * @param P  output; must be pre-sized by the caller to
 *           A.size() x B[0].size(). Every cell is overwritten. If an operand
 *           is empty the function returns without touching P.
 */
void Strassen(vector<vector<int>> &A, vector<vector<int>> &B, vector<vector<int>> &P)
{
    // Nothing to multiply; this also keeps the B[0] accesses below valid.
    if (A.empty() || B.empty() || B[0].empty())
    {
        return;
    }

    // Base case: 1x1 times 1x1.
    if (A.size() == 1 && B.size() == 1 && B[0].size() == 1)
    {
        P[0][0] = A[0][0] * B[0][0];
        return;
    }

    // Dimensions rounded up to even so that every matrix splits into four
    // equally sized quadrants.
    const size_t rows = A.size() + A.size() % 2;          // padded s
    const size_t inner = B.size() + B.size() % 2;         // padded n
    const size_t cols = B[0].size() + B[0].size() % 2;    // padded m

    // Zero-padded working copies; the caller's matrices stay untouched.
    vector<vector<int>> a(A);
    a.resize(rows, vector<int>(inner, 0));
    for (vector<int>& row : a)
    {
        row.resize(inner, 0);
    }
    vector<vector<int>> b(B);
    b.resize(inner, vector<int>(cols, 0));
    for (vector<int>& row : b)
    {
        row.resize(cols, 0);
    }

    // Quadrant sizes: A quadrants are hr x hi, B quadrants are hi x hc and
    // product quadrants are hr x hc.
    const size_t hr = rows / 2;
    const size_t hi = inner / 2;
    const size_t hc = cols / 2;

    // Helper calls take (x, y) = (column, row) with inclusive corners.
    vector<vector<int>> S1(hi, vector<int>(hc, 0));    // S1 = B12 - B22
    matrixSub(b, hc, 0, cols - 1, hi - 1, b, hc, hi, S1);
    vector<vector<int>> S2(hr, vector<int>(hi, 0));    // S2 = A11 + A12
    matrixAdd(a, 0, 0, hi - 1, hr - 1, a, hi, 0, S2);
    vector<vector<int>> S3(hr, vector<int>(hi, 0));    // S3 = A21 + A22
    matrixAdd(a, 0, hr, hi - 1, rows - 1, a, hi, hr, S3);
    vector<vector<int>> S4(hi, vector<int>(hc, 0));    // S4 = B21 - B11
    matrixSub(b, 0, hi, hc - 1, inner - 1, b, 0, 0, S4);
    vector<vector<int>> S5(hr, vector<int>(hi, 0));    // S5 = A11 + A22
    matrixAdd(a, 0, 0, hi - 1, hr - 1, a, hi, hr, S5);
    vector<vector<int>> S6(hi, vector<int>(hc, 0));    // S6 = B11 + B22
    matrixAdd(b, 0, 0, hc - 1, hi - 1, b, hc, hi, S6);
    vector<vector<int>> S7(hr, vector<int>(hi, 0));    // S7 = A12 - A22
    matrixSub(a, hi, 0, inner - 1, hr - 1, a, hi, hr, S7);
    vector<vector<int>> S8(hi, vector<int>(hc, 0));    // S8 = B21 + B22
    matrixAdd(b, 0, hi, hc - 1, inner - 1, b, hc, hi, S8);
    vector<vector<int>> S9(hr, vector<int>(hi, 0));    // S9 = A11 - A21
    matrixSub(a, 0, 0, hi - 1, hr - 1, a, 0, hr, S9);
    vector<vector<int>> S10(hi, vector<int>(hc, 0));   // S10 = B11 + B12
    matrixAdd(b, 0, 0, hc - 1, hi - 1, b, hc, 0, S10);

    // Seven recursive products.
    vector<vector<int>> A11(hr, vector<int>(hi, 0));
    extract(a, 0, 0, hi - 1, hr - 1, A11);
    vector<vector<int>> P1(hr, vector<int>(hc, 0));    // P1 = A11 * S1
    Strassen(A11, S1, P1);

    vector<vector<int>> B22(hi, vector<int>(hc, 0));
    extract(b, hc, hi, cols - 1, inner - 1, B22);
    vector<vector<int>> P2(hr, vector<int>(hc, 0));    // P2 = S2 * B22
    Strassen(S2, B22, P2);

    vector<vector<int>> B11(hi, vector<int>(hc, 0));
    extract(b, 0, 0, hc - 1, hi - 1, B11);
    vector<vector<int>> P3(hr, vector<int>(hc, 0));    // P3 = S3 * B11
    Strassen(S3, B11, P3);

    vector<vector<int>> A22(hr, vector<int>(hi, 0));
    extract(a, hi, hr, inner - 1, rows - 1, A22);
    vector<vector<int>> P4(hr, vector<int>(hc, 0));    // P4 = A22 * S4
    Strassen(A22, S4, P4);

    vector<vector<int>> P5(hr, vector<int>(hc, 0));    // P5 = S5 * S6
    Strassen(S5, S6, P5);
    vector<vector<int>> P6(hr, vector<int>(hc, 0));    // P6 = S7 * S8
    Strassen(S7, S8, P6);
    vector<vector<int>> P7(hr, vector<int>(hc, 0));    // P7 = S9 * S10
    Strassen(S9, S10, P7);

    // Assemble the result quadrants. The windows pasted into P are clipped
    // to P's own (unpadded) size, which drops the padding row/column again.
    vector<vector<int>> C11(hr, vector<int>(hc, 0));   // C11 = P5 + P4 - P2 + P6
    matrixAdd(P5, 0, 0, hc - 1, hr - 1, P4, 0, 0, C11);
    matrixSub(C11, 0, 0, hc - 1, hr - 1, P2, 0, 0, C11);
    matrixAdd(C11, 0, 0, hc - 1, hr - 1, P6, 0, 0, C11);
    add(P, 0, 0, hc - 1, hr - 1, C11);

    vector<vector<int>> C12(hr, vector<int>(hc, 0));   // C12 = P1 + P2
    matrixAdd(P1, 0, 0, hc - 1, hr - 1, P2, 0, 0, C12);
    add(P, hc, 0, P[0].size() - 1, hr - 1, C12);

    vector<vector<int>> C21(hr, vector<int>(hc, 0));   // C21 = P3 + P4
    matrixAdd(P3, 0, 0, hc - 1, hr - 1, P4, 0, 0, C21);
    add(P, 0, hr, hc - 1, P.size() - 1, C21);

    vector<vector<int>> C22(hr, vector<int>(hc, 0));   // C22 = P5 + P1 - P3 - P7
    matrixAdd(P5, 0, 0, hc - 1, hr - 1, P1, 0, 0, C22);
    matrixSub(C22, 0, 0, hc - 1, hr - 1, P3, 0, 0, C22);
    matrixSub(C22, 0, 0, hc - 1, hr - 1, P7, 0, 0, C22);
    add(P, hc, hr, P[0].size() - 1, P.size() - 1, C22);
}
/**
 * Schoolbook matrix product used as the reference implementation.
 *
 * Accumulates left * right INTO `result` (each cell is increased by the
 * corresponding dot product), so `result` must be zero-filled by the caller
 * to obtain the plain product.
 *
 * @param left    left operand; each row is indexed up to right.size() - 1
 * @param right   right operand; right[0].size() gives the column count
 * @param result  accumulator, indexed as [row of left][column of right]
 */
void matrixMultiply(vector<vector<int>>& left, vector<vector<int>>& right, vector<vector<int>>& result)
{
    for (size_t row = 0; row != left.size(); ++row)
    {
        for (size_t col = 0; col != right[0].size(); ++col)
        {
            // Dot product of row `row` of left with column `col` of right,
            // added straight onto the existing cell.
            int& cell = result[row][col];
            for (size_t inner = 0; inner != right.size(); ++inner)
            {
                cell += left[row][inner] * right[inner][col];
            }
        }
    }
}
/**
 * Self-test driver.
 *
 * For every shape (NA x MA) * (MA x MB) with 1 <= NA, MB, MA <= 10 it builds
 * two test matrices, prints them, multiplies them with Strassen() and with
 * matrixMultiply(), and compares the two products. On the first mismatch an
 * error message is printed and the program ends with status -1; otherwise
 * the product is printed and the next shape is tried.
 *
 * To debug one particular shape, narrow the three loop ranges (for example
 * NA = 3, MA = 4, MB = 5).
 *
 * @return 0 when every shape matched, -1 on the first mismatch
 */
int main()
{
    // Builds a rows x cols matrix filled row by row with a shuffled sequence
    // 1..rows*cols. A default-constructed engine is created for every call,
    // so the same shape always yields the same matrix.
    const auto makeShuffledMatrix = [](int rows, int cols) -> vector<vector<int>>
    {
        vector<int> values(rows * cols, 0);
        for (int i = 0; i < rows * cols; ++i)
        {
            values[i] = i + 1;
        }
        shuffle(values.begin(), values.end(), default_random_engine());

        vector<vector<int>> matrix(rows, vector<int>(cols, 0));
        size_t next = 0;
        for (vector<int>& row : matrix)
        {
            for (int& cell : row)
            {
                cell = values[next++];
            }
        }
        return matrix;
    };

    // Prints one matrix row per line, every entry followed by a space.
    const auto printMatrix = [](const vector<vector<int>>& matrix)
    {
        for (const vector<int>& row : matrix)
        {
            for (int value : row)
            {
                cout << value << " ";
            }
            cout << '\n';
        }
    };

    for (int NA = 1; NA <= 10; ++NA)
    {
        for (int MB = 1; MB <= 10; ++MB)
        {
            for (int MA = 1; MA <= 10; ++MA)
            {
                const int NB = MA;  // inner dimensions must agree
                vector<vector<int>> A = makeShuffledMatrix(NA, MA);
                vector<vector<int>> B = makeShuffledMatrix(NB, MB);

                cout << "相乘的两个矩阵为" << '\n';
                printMatrix(A);
                cout << '\n';
                printMatrix(B);
                // Make the operands visible even if the multiplication
                // below were to crash.
                cout.flush();

                vector<vector<int>> result1(NA, vector<int>(MB, 0));
                vector<vector<int>> result2(NA, vector<int>(MB, 0));

                // Strassen() takes its operands by mutable reference, so it
                // gets scratch copies; the reference product always sees the
                // untouched originals.
                vector<vector<int>> scratchA = A;
                vector<vector<int>> scratchB = B;
                Strassen(scratchA, scratchB, result1);
                matrixMultiply(A, B, result2);

                if (result1 != result2)
                {
                    cout << "矩阵相乘结果错误!" << '\n';
                    // Returning (rather than calling exit) unwinds the local
                    // matrices normally; the exit status is unchanged.
                    return -1;
                }

                cout << "矩阵相乘结果正确,乘积为" << '\n';
                printMatrix(result1);
            }
        }
    }
    return 0;
}