问题描述:
矩阵乘法的优化
矩阵乘法是线性代数中的基本运算之一,但普通的矩阵乘法算法在计算复杂度上较高。斯特拉森矩阵乘法算法是一种通过分治策略减少乘法次数的算法,能够显著提高矩阵乘法的效率。
问题分析:
在矩阵乘法中,A矩阵和B矩阵可以做乘法运算必须满足A矩阵的列的数量等于B矩阵的行的数量。
运算规则:A的每一行中的数字对应乘以B的每一列的数字把结果相加起来。
普通的矩阵乘法需要进行三重循环,时间复杂度为O(n^3),而斯特拉森矩阵乘法通过将原始矩阵分解成较小的子矩阵,然后利用这些子矩阵进行计算,来减少重复计算和乘法次数。
我们不断递归对矩阵做分治拆解,简单来说就是把一个大块矩阵不断做斯特拉森拆解,考虑速度在斯特拉森函数内部只做加减法,当拆解到一定大小的矩阵时,将每个小矩阵投入普通的矩阵乘法运算,至于这个 ‘一定大小’ 在下面的代码中定义为32,很有可能在64达到最快速度。
需要知道的是,斯特拉森算法只是对矩阵分治的算法而不是单独的乘法算法,分治完成时最后使用的还是普通矩阵乘法,在阶数小于等于32,时普通的矩乘法会有更快的速度,而随着矩阵的阶不断增加,斯特拉森可以提供更快的速度
算法设计:
1. 如果两个矩阵A和B的维度相同,并且是2的幂次方,如2x2,4x4,8x8等,则可以按照斯特拉森矩阵乘法算法进行计算。
2. 将原始的A和B矩阵分别分解成四块小矩阵,记作A11, A12, A21, A22, B11, B12, B21, B22。
3. 对这些小矩阵进行七次乘法运算和十次加法运算,得到四个结果矩阵C11, C12, C21, C22。
4. 将这四个结果矩阵按照一定规则组合起来,得到最终的结果矩阵C。
进行如下矩阵分割:
分割之后我们先进行如下一些预处理运算
S1 = B12- B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5= A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
然后我们根据S1~S2的值,进行7次乘法运算,如下:
P 1= A11・S1
P2 = S2·B22
P3 = S3 ・ B11
P4=A22. S4
P5= S5. S6
P6 = S7. S8
P7 = S9 ・ S10
然后再通过简单的加法运算,就可以得到结果矩阵的四个子矩阵:
C11 = P5+P4 - P2 + P6
C12 =P1+P2
C21 = P3 + P4
C22 = P5 + P1 - P3- P7
最后将子矩阵合并即可
实现代码:
#include <iostream>
#include <vector>
// 功能:矩阵相加
std::vector<std::vector<int>> matrixAdd(const std::vector<std::vector<int>>& A, const std::vector<std::vector<int>>& B) {
int n = A.size();
int m = A[0].size();
std::vector<std::vector<int>> C(n, std::vector<int>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
// 功能:矩阵相减
std::vector<std::vector<int>> matrixSub(const std::vector<std::vector<int>>& A, const std::vector<std::vector<int>>& B) {
int n = A.size();
int m = A[0].size();
std::vector<std::vector<int>> C(n, std::vector<int>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
// 功能:矩阵斯特拉森算法
std::vector<std::vector<int>> strassen(const std::vector<std::vector<int>>& A, const std::vector<std::vector<int>>& B) {
int n = A.size();
if (n == 1) {
std::vector<std::vector<int>> C(1, std::vector<int>(1));
C[0][0] = A[0][0] * B[0][0];
return C;
} else {
int halfSize = n / 2;
std::vector<std::vector<int>> A11(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A12(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A21(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A22(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B11(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B12(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B21(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B22(halfSize, std::vector<int>(halfSize));
// 将矩阵 A 和 B 拆分成四个子矩阵
for (int i = 0; i < halfSize; i++) {
for (int j = 0; j < halfSize; j++) {
A11[i][j] = A[i][j];
A12[i][j] = A[i][j + halfSize];
A21[i][j] = A[i + halfSize][j];
A22[i][j] = A[i + halfSize][j + halfSize];
B11[i][j] = B[i][j];
B12[i][j] = B[i][j + halfSize];
B21[i][j] = B[i + halfSize][j];
B22[i][j] = B[i + halfSize][j + halfSize];
}
}
// 计算七个矩阵乘法的递归计算
std::vector<std::vector<int>> S1 = matrixSub(B12, B22);
std::vector<std::vector<int>> S2 = matrixAdd(A11, A12);
std::vector<std::vector<int>> S3 = matrixAdd(A21, A22);
std::vector<std::vector<int>> S4 = matrixSub(B21, B11);
std::vector<std::vector<int>> S5 = matrixAdd(A11, A22);
std::vector<std::vector<int>> S6 = matrixAdd(B11, B22);
std::vector<std::vector<int>> S7 = matrixSub(A12, A22);
std::vector<std::vector<int>> S8 = matrixAdd(B21, B22);
std::vector<std::vector<int>> S9 = matrixSub(A11, A21);
std::vector<std::vector<int>> S10 = matrixAdd(B11, B12);
std::vector<std::vector<int>> P1 = strassen(A11, S1);
std::vector<std::vector<int>> P2 = strassen(S2, B22);
std::vector<std::vector<int>> P3 = strassen(S3, B11);
std::vector<std::vector<int>> P4 = strassen(A22, S4);
std::vector<std::vector<int>> P5 = strassen(S5, S6);
std::vector<std::vector<int>> P6 = strassen(S7, S8);
std::vector<std::vector<int>> P7 = strassen(S9, S10);
std::vector<std::vector<int>> C11 = matrixAdd(matrixSub(matrixAdd(P5, P4), P2), P6);
std::vector<std::vector<int>> C12 = matrixAdd(P1, P2);
std::vector<std::vector<int>> C21 = matrixAdd(P3, P4);
std::vector<std::vector<int>> C22 = matrixSub(matrixSub(matrixAdd(P5, P1), P3), P7);
std::vector<std::vector<int>> C(n, std::vector<int>(n));
for (int i = 0; i < halfSize; i++) {
for (int j = 0; j < halfSize; j++) {
C[i][j] = C11[i][j];
C[i][j + halfSize] = C12[i][j];
C[i + halfSize][j] = C21[i][j];
C[i + halfSize][j + halfSize] = C22[i][j];
}
}
return C;
}
}
int main() {
std::vector<std::vector<int>> A = {{1, 2}, {3, 4}};
std::vector<std::vector<int>> B = {{5, 6}, {7, 8}};
std::vector<std::vector<int>> C = strassen(A, B);
for (const auto& row : C) {
for (int value : row) {
std::cout << value << " ";
}
std::cout << std::endl;
}
return 0;
为了方便提交作业,没有设置输入的代码,有需要可以在main函数自己改矩阵内容,或者改输入
2*2运行结果:
3×3运行结果: