Strassen's Algorithm
Time complexity: $\Theta\left(n^{\log_{2}7}\right) \approx \Theta\left(n^{2.8074}\right)$
The main idea:
The core of Strassen's algorithm is to make the recursion tree a little less bushy: it recurses on only $7$, rather than $8$, multiplications of $n/2 \times n/2$ matrices.
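A rough sketch of where the exponent comes from, assuming the standard master-theorem argument (which the original does not spell out): each call makes 7 recursive multiplications on half-size matrices plus a constant number of $\Theta(n^2)$ matrix additions and subtractions. With 8 recursive multiplications the same recurrence only gives the usual cubic bound.

```latex
\begin{aligned}
T(n) &= 7\,T(n/2) + \Theta(n^{2}) &&\Longrightarrow\quad T(n) = \Theta\left(n^{\log_{2}7}\right) \approx \Theta\left(n^{2.8074}\right) \\
T(n) &= 8\,T(n/2) + \Theta(n^{2}) &&\Longrightarrow\quad T(n) = \Theta\left(n^{\log_{2}8}\right) = \Theta\left(n^{3}\right)
\end{aligned}
```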
Let $A$ and $B$ be two square matrices over a ring $R$. We want to compute the matrix product $C$ as:
$$C = AB, \qquad A, B, C \in R^{2^{n} \times 2^{n}}$$
If $A$ and $B$ are not of size $2^{n} \times 2^{n}$, we fill the missing rows and columns with $0$.
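The zero-padding step can be sketched as follows. This is only an illustration: the `Matrix` alias and the helper names `NextPowerOfTwo` and `PadToPowerOfTwo` are made up for this example and do not appear in the code later in this post, which works on `int **` matrices and expects the caller to pass a power-of-two size.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical alias for this sketch; the article itself uses int** matrices.
using Matrix = std::vector<std::vector<int>>;

// Smallest power of two that is >= n (for n >= 1).
std::size_t NextPowerOfTwo(std::size_t n)
{
    assert(n >= 1);
    std::size_t p = 1;
    while (p < n)
        p *= 2;
    return p;
}

// Copies a rows x cols matrix (all rows assumed to have the same length)
// into the top-left corner of a size x size zero matrix, where size is the
// next power of two covering both dimensions.
Matrix PadToPowerOfTwo(const Matrix &m)
{
    assert(!m.empty());
    std::size_t rows = m.size();
    std::size_t cols = m[0].size();
    std::size_t size = NextPowerOfTwo(rows > cols ? rows : cols);
    Matrix padded(size, std::vector<int>(size, 0));
    for (std::size_t i = 0; i < rows; i++)
        for (std::size_t j = 0; j < cols; j++)
            padded[i][j] = m[i][j];
    return padded;
}
```

For a 3 x 3 input this yields a 4 x 4 matrix whose last row and column are zero. Since the padding is all zeros, the top-left 3 x 3 block of the padded product equals the product of the original matrices.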
We partition $A$, $B$ and $C$ into equally sized block matrices:
$$A=\begin{bmatrix}A_{1,1} & A_{1,2} \\ A_{2,1} & A_{2,2}\end{bmatrix}, \quad B=\begin{bmatrix}B_{1,1} & B_{1,2} \\ B_{2,1} & B_{2,2}\end{bmatrix}, \quad C=\begin{bmatrix}C_{1,1} & C_{1,2} \\ C_{2,1} & C_{2,2}\end{bmatrix}$$
The naive algorithm gives:
$$C_{1,1}=A_{1,1}\times B_{1,1}+A_{1,2}\times B_{2,1}$$
$$C_{1,2}=A_{1,1}\times B_{1,2}+A_{1,2}\times B_{2,2}$$
$$C_{2,1}=A_{2,1}\times B_{1,1}+A_{2,2}\times B_{2,1}$$
$$C_{2,2}=A_{2,1}\times B_{1,2}+A_{2,2}\times B_{2,2}$$
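However, this construction alone does not reduce the number of multiplications. We still need $8$ multiplications to compute the $C_{i,j}$ blocks, the same number as standard matrix multiplication. Strassen's algorithm instead defines the new matrices: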
$$M_{1}=\left(A_{1,1}+A_{2,2}\right)\left(B_{1,1}+B_{2,2}\right)$$
$$M_{2}=\left(A_{2,1}+A_{2,2}\right)B_{1,1}$$
$$M_{3}=A_{1,1}\left(B_{1,2}-B_{2,2}\right)$$
$$M_{4}=A_{2,2}\left(B_{2,1}-B_{1,1}\right)$$
$$M_{5}=\left(A_{1,1}+A_{1,2}\right)B_{2,2}$$
$$M_{6}=\left(A_{2,1}-A_{1,1}\right)\left(B_{1,1}+B_{1,2}\right)$$
$$M_{7}=\left(A_{1,2}-A_{2,2}\right)\left(B_{2,1}+B_{2,2}\right)$$
This uses only $7$ multiplications (one for each $M_{k}$) instead of $8$. Each $C_{i,j}$ can now be expressed by adding and subtracting different combinations of the $M_{k}$:
$$C_{1,1}=M_{1}+M_{4}-M_{5}+M_{7}$$
$$C_{1,2}=M_{3}+M_{5}$$
$$C_{2,1}=M_{2}+M_{4}$$
$$C_{2,2}=M_{1}-M_{2}+M_{3}+M_{6}$$
(Source: Wikipedia)
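A quick way to convince yourself of these identities is to try them on plain integers, where each block is a single number. The sketch below does that for the 2 x 2 case; the `Mat2` layout and the function names are assumptions made for this example only.

```cpp
#include <array>
#include <cassert>

// 2 x 2 integer matrix stored as {a11, a12, a21, a22}; a layout chosen
// only for this sketch.
using Mat2 = std::array<int, 4>;

// Standard product: 8 scalar multiplications.
Mat2 MultiplyNaive(const Mat2 &a, const Mat2 &b)
{
    return {a[0] * b[0] + a[1] * b[2], a[0] * b[1] + a[1] * b[3],
            a[2] * b[0] + a[3] * b[2], a[2] * b[1] + a[3] * b[3]};
}

// Strassen's product: 7 scalar multiplications M1..M7, combined exactly
// as in the formulas for C11, C12, C21 and C22.
Mat2 MultiplyStrassen(const Mat2 &a, const Mat2 &b)
{
    int m1 = (a[0] + a[3]) * (b[0] + b[3]);
    int m2 = (a[2] + a[3]) * b[0];
    int m3 = a[0] * (b[1] - b[3]);
    int m4 = a[3] * (b[2] - b[0]);
    int m5 = (a[0] + a[1]) * b[3];
    int m6 = (a[2] - a[0]) * (b[0] + b[1]);
    int m7 = (a[1] - a[3]) * (b[2] + b[3]);
    return {m1 + m4 - m5 + m7, m3 + m5, m2 + m4, m1 - m2 + m3 + m6};
}

// Checks the identity on one pair of inputs; aborts if the two disagree.
void CheckStrassenIdentity(const Mat2 &a, const Mat2 &b)
{
    assert(MultiplyStrassen(a, b) == MultiplyNaive(a, b));
}
```

For example, with `a = {1, 2, 3, 4}` and `b = {5, 6, 7, 8}` both functions return `{19, 22, 43, 50}`.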
Let's verify one of them:
$$\begin{aligned}
C_{1,2} &= M_{3}+M_{5} \\
&= A_{1,1}\left(B_{1,2}-B_{2,2}\right)+\left(A_{1,1}+A_{1,2}\right)B_{2,2} \\
&= A_{1,1}\times B_{1,2}-A_{1,1}\times B_{2,2}+A_{1,1}\times B_{2,2}+A_{1,2}\times B_{2,2} \\
&= A_{1,1}\times B_{1,2}+A_{1,2}\times B_{2,2}
\end{aligned}$$
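The same kind of expansion works for the other blocks. As one more worked case (added here as a check, using only the definitions of $M_{2}$ and $M_{4}$), $C_{2,1}$ also reduces to the naive formula:

```latex
\begin{aligned}
C_{2,1} &= M_{2}+M_{4} \\
&= \left(A_{2,1}+A_{2,2}\right)B_{1,1}+A_{2,2}\left(B_{2,1}-B_{1,1}\right) \\
&= A_{2,1}\times B_{1,1}+A_{2,2}\times B_{1,1}+A_{2,2}\times B_{2,1}-A_{2,2}\times B_{1,1} \\
&= A_{2,1}\times B_{1,1}+A_{2,2}\times B_{2,1}
\end{aligned}
```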
Pseudocode:
// The pseudocode below is adapted from someone else's; I was too lazy to write my own...
STRASSEN(A, B)
    Length = number of rows (equal to the number of columns) of A and B
    let C be a new Length * Length matrix
    if Length == 1
        C = A * B
    else partition A, B, and C into four (Length / 2) * (Length / 2) blocks each, as above
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        P1 = STRASSEN(A11, S1)
        P2 = STRASSEN(S2, B22)
        P3 = STRASSEN(S3, B11)
        P4 = STRASSEN(A22, S4)
        P5 = STRASSEN(S5, S6)
        P6 = STRASSEN(S7, S8)
        P7 = STRASSEN(S9, S10)
        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7
    return C
C++ code:
// The following code is adapted from https://blog.csdn.net/zhuangxiaobin/article/details/36476769
/*
 * Matrix addition, time complexity: O(n^2)
 */
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] + Matrix_B[i][j];
}
/*
 * Matrix subtraction, time complexity: O(n^2)
 */
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] - Matrix_B[i][j];
}
/*
 * Strassen's algorithm (matrix multiplication)
 *
 * Parameters:
 * Matrix_A is a pointer to a pointer. It should point to an array of pointers,
 * each of which points to an ordinary array, together forming a two-dimensional array.
 * Matrix_B and Matrix_C are the same.
 * The fourth parameter (length) is the side length of the matrices.
 * Note: the matrices must be square and length must be a power of two. This
 * function does not pad its inputs, so if the dimensions differ or are not a
 * power of two, the caller has to zero-pad the matrices to length * length first.
 */
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
if (length == 1)
Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];
else
{
int Middle = length / 2;
/*
 * When a built-in two-dimensional array is passed to a function, the size
 * of its second dimension must be fixed, so the length / 2 subproblems of
 * the recursion could not be passed as arguments.
 *
 * Instead we declare pointers to pointers. Each points to an array of
 * pointers whose elements in turn point to one-dimensional arrays, which
 * emulates a two-dimensional array.
 */
int **Matrix_A_11 = new int *[Middle];
int **Matrix_A_12 = new int *[Middle];
int **Matrix_A_21 = new int *[Middle];
int **Matrix_A_22 = new int *[Middle];
int **Matrix_B_11 = new int *[Middle];
int **Matrix_B_12 = new int *[Middle];
int **Matrix_B_21 = new int *[Middle];
int **Matrix_B_22 = new int *[Middle];
int **Matrix_C_11 = new int *[Middle];
int **Matrix_C_12 = new int *[Middle];
int **Matrix_C_21 = new int *[Middle];
int **Matrix_C_22 = new int *[Middle];
int **M1 = new int *[Middle];
int **M2 = new int *[Middle];
int **M3 = new int *[Middle];
int **M4 = new int *[Middle];
int **M5 = new int *[Middle];
int **M6 = new int *[Middle];
int **M7 = new int *[Middle];
int **Result_1 = new int *[Middle];
int **Result_2 = new int *[Middle];
for (int i = 0; i < Middle; i++)
{
Matrix_A_11[i] = new int [Middle];
Matrix_A_12[i] = new int [Middle];
Matrix_A_21[i] = new int [Middle];
Matrix_A_22[i] = new int [Middle];
Matrix_B_11[i] = new int [Middle];
Matrix_B_12[i] = new int [Middle];
Matrix_B_21[i] = new int [Middle];
Matrix_B_22[i] = new int [Middle];
Matrix_C_11[i] = new int [Middle];
Matrix_C_12[i] = new int [Middle];
Matrix_C_21[i] = new int [Middle];
Matrix_C_22[i] = new int [Middle];
M1[i] = new int [Middle];
M2[i] = new int [Middle];
M3[i] = new int [Middle];
M4[i] = new int [Middle];
M5[i] = new int [Middle];
M6[i] = new int [Middle];
M7[i] = new int [Middle];
Result_1[i] = new int [Middle];
Result_2[i] = new int [Middle];
}
/*
 * Now fill these "two-dimensional arrays" with their values:
 *
 * Matrix_A_11 takes the elements of Matrix_A at rows (0 ~ Middle - 1), columns (0 ~ Middle - 1);
 * Matrix_A_12 takes the elements of Matrix_A at rows (0 ~ Middle - 1), columns (Middle ~ length - 1);
 * Matrix_A_21 takes the elements of Matrix_A at rows (Middle ~ length - 1), columns (0 ~ Middle - 1);
 * Matrix_A_22 takes the elements of Matrix_A at rows (Middle ~ length - 1), columns (Middle ~ length - 1).
 *
 * The blocks of Matrix_B are filled in the same way.
 */
for (int i = 0; i < Middle; i++)
{
for (int j = 0; j < Middle; j++)
{
Matrix_A_11[i][j] = Matrix_A[i][j];
Matrix_A_12[i][j] = Matrix_A[i][j + Middle];
Matrix_A_21[i][j] = Matrix_A[i + Middle][j];
Matrix_A_22[i][j] = Matrix_A[i + Middle][j + Middle];
Matrix_B_11[i][j] = Matrix_B[i][j];
Matrix_B_12[i][j] = Matrix_B[i][j + Middle];
Matrix_B_21[i][j] = Matrix_B[i + Middle][j];
Matrix_B_22[i][j] = Matrix_B[i + Middle][j + Middle];
}
}
/*
 * Recursive matrix multiplications
 *
 * Before each recursive call, Add and Subtract compute the required sums
 * and differences, and Result_1 and Result_2 hold them temporarily.
 */
// M1
Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
// M2
Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
// M3
Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
// M4
Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
// M5
Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
// M6
Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
// M7
Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);
/*
 * Following Strassen's algorithm, combine the recursively computed
 * matrices M1...M7 by additions and subtractions to obtain Matrix_C_11,
 * Matrix_C_12, Matrix_C_21 and Matrix_C_22.
 *
 * Add and Subtract do the arithmetic, and Result_1 and Result_2 hold the
 * intermediate results.
 */
// Matrix_C_11
Add(M1, M4, Result_1, Middle);
Subtract(Result_1, M5, Result_2, Middle);
Add(Result_2, M7, Matrix_C_11, Middle);
// Matrix_C_12
Add(M3, M5, Matrix_C_12, Middle);
// Matrix_C_21
Add(M2, M4, Matrix_C_21, Middle);
// Matrix_C_22
Subtract(M1, M2, Result_1, Middle);
Add(Result_1, M3, Result_2, Middle);
Add(Result_2, M6, Matrix_C_22, Middle);
/*
 * Now assemble the four submatrices back into one large matrix
 */
for (int i = 0; i < Middle; i++)
{
for (int j = 0; j < Middle; j++)
{
Matrix_C[i][j] = Matrix_C_11[i][j];
Matrix_C[i][j + Middle] = Matrix_C_12[i][j];
Matrix_C[i + Middle][j] = Matrix_C_21[i][j];
Matrix_C[i + Middle][j + Middle] = Matrix_C_22[i][j];
}
}
/*
 * Finally, release the dynamically allocated memory
 */
for (int i = 0; i < Middle; i++)
{
delete[] Matrix_A_11[i];
delete[] Matrix_A_12[i];
delete[] Matrix_A_21[i];
delete[] Matrix_A_22[i];
delete[] Matrix_B_11[i];
delete[] Matrix_B_12[i];
delete[] Matrix_B_21[i];
delete[] Matrix_B_22[i];
delete[] Matrix_C_11[i];
delete[] Matrix_C_12[i];
delete[] Matrix_C_21[i];
delete[] Matrix_C_22[i];
delete[] M1[i];
delete[] M2[i];
delete[] M3[i];
delete[] M4[i];
delete[] M5[i];
delete[] M6[i];
delete[] M7[i];
delete[] Result_1[i];
delete[] Result_2[i];
}
delete[] Matrix_A_11;
delete[] Matrix_A_12;
delete[] Matrix_A_21;
delete[] Matrix_A_22;
delete[] Matrix_B_11;
delete[] Matrix_B_12;
delete[] Matrix_B_21;
delete[] Matrix_B_22;
delete[] Matrix_C_11;
delete[] Matrix_C_12;
delete[] Matrix_C_21;
delete[] Matrix_C_22;
delete[] M1;
delete[] M2;
delete[] M3;
delete[] M4;
delete[] M5;
delete[] M6;
delete[] M7;
delete[] Result_1;
delete[] Result_2;
}
}
A slightly simplified version of my own:
/*
 * A slightly improved Strassen implementation:
 *
 * It tries to reduce the space used by the previous version.
 *
 * It relies on array indexing and pointer arithmetic: the 12 dynamically
 * allocated (n / 2) * (n / 2) submatrices of A, B and C are replaced by
 * 12 one-dimensional arrays of n / 2 row pointers that point directly
 * into the original matrices.
 */
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
if (length == 1)
Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];
else
{
int Middle = length / 2;
int **Matrix_A_11 = new int *[Middle];
int **Matrix_A_12 = new int *[Middle];
int **Matrix_A_21 = new int *[Middle];
int **Matrix_A_22 = new int *[Middle];
int **Matrix_B_11 = new int *[Middle];
int **Matrix_B_12 = new int *[Middle];
int **Matrix_B_21 = new int *[Middle];
int **Matrix_B_22 = new int *[Middle];
int **Matrix_C_11 = new int *[Middle];
int **Matrix_C_12 = new int *[Middle];
int **Matrix_C_21 = new int *[Middle];
int **Matrix_C_22 = new int *[Middle];
int **M1 = new int *[Middle];
int **M2 = new int *[Middle];
int **M3 = new int *[Middle];
int **M4 = new int *[Middle];
int **M5 = new int *[Middle];
int **M6 = new int *[Middle];
int **M7 = new int *[Middle];
int **Result_1 = new int *[Middle];
int **Result_2 = new int *[Middle];
for (int i = 0; i < Middle; i++)
{
M1[i] = new int [Middle];
M2[i] = new int [Middle];
M3[i] = new int [Middle];
M4[i] = new int [Middle];
M5[i] = new int [Middle];
M6[i] = new int [Middle];
M7[i] = new int [Middle];
Result_1[i] = new int [Middle];
Result_2[i] = new int [Middle];
}
for (int i = 0; i < Middle; i++)
{
Matrix_A_11[i] = Matrix_A[i];
Matrix_A_12[i] = Matrix_A[i] + Middle;
Matrix_A_21[i] = Matrix_A[i + Middle];
Matrix_A_22[i] = Matrix_A[i + Middle] + Middle;
Matrix_B_11[i] = Matrix_B[i];
Matrix_B_12[i] = Matrix_B[i] + Middle;
Matrix_B_21[i] = Matrix_B[i + Middle];
Matrix_B_22[i] = Matrix_B[i + Middle] + Middle;
Matrix_C_11[i] = Matrix_C[i];
Matrix_C_12[i] = Matrix_C[i] + Middle;
Matrix_C_21[i] = Matrix_C[i + Middle];
Matrix_C_22[i] = Matrix_C[i + Middle] + Middle;
}
/*
 * Recursive matrix multiplications
 *
 * Before each recursive call, Add and Subtract compute the required sums
 * and differences, and Result_1 and Result_2 hold them temporarily.
 */
// M1
Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
// M2
Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
// M3
Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
// M4
Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
// M5
Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
// M6
Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
// M7
Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);
/*
 * Following Strassen's algorithm, combine the recursively computed
 * matrices M1...M7 by additions and subtractions to obtain Matrix_C_11,
 * Matrix_C_12, Matrix_C_21 and Matrix_C_22. Because these four blocks
 * point directly into Matrix_C, no copy-back step is needed.
 *
 * Add and Subtract do the arithmetic, and Result_1 and Result_2 hold the
 * intermediate results.
 */
// Matrix_C_11
Add(M1, M4, Result_1, Middle);
Subtract(Result_1, M5, Result_2, Middle);
Add(Result_2, M7, Matrix_C_11, Middle);
// Matrix_C_12
Add(M3, M5, Matrix_C_12, Middle);
// Matrix_C_21
Add(M2, M4, Matrix_C_21, Middle);
// Matrix_C_22
Subtract(M1, M2, Result_1, Middle);
Add(Result_1, M3, Result_2, Middle);
Add(Result_2, M6, Matrix_C_22, Middle);
/*
 * Finally, release the dynamically allocated memory
 */
for (int i = 0; i < Middle; i++)
{
delete[] M1[i];
delete[] M2[i];
delete[] M3[i];
delete[] M4[i];
delete[] M5[i];
delete[] M6[i];
delete[] M7[i];
delete[] Result_1[i];
delete[] Result_2[i];
}
delete[] Matrix_A_11;
delete[] Matrix_A_12;
delete[] Matrix_A_21;
delete[] Matrix_A_22;
delete[] Matrix_B_11;
delete[] Matrix_B_12;
delete[] Matrix_B_21;
delete[] Matrix_B_22;
delete[] Matrix_C_11;
delete[] Matrix_C_12;
delete[] Matrix_C_21;
delete[] Matrix_C_22;
delete[] M1;
delete[] M2;
delete[] M3;
delete[] M4;
delete[] M5;
delete[] M6;
delete[] M7;
delete[] Result_1;
delete[] Result_2;
}
}
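The pointer trick used by the simplified version can be isolated in a few lines. The helper below is a hypothetical name introduced only for this illustration (the function above inlines the same assignments): it fills an array of row pointers so that it behaves like a submatrix without copying any elements.

```cpp
#include <cassert>

// Fills `view` (an array of `half` row pointers supplied by the caller) so
// that view[i][j] aliases matrix[row_offset + i][col_offset + j].
// No matrix elements are copied.
void MakeBlockView(int **matrix, int **view, int half, int row_offset, int col_offset)
{
    assert(half > 0);
    for (int i = 0; i < half; i++)
        view[i] = matrix[row_offset + i] + col_offset;
}
```

With a 4 x 4 matrix, `MakeBlockView(A, A_12, 2, 0, 2)` makes `A_12[0][0]` refer to `A[0][2]`, and writing through the view changes the original matrix. That aliasing is why the C blocks need no copy-back step, and it also means the output matrix must not share storage with either input.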
Test code:
#include <iostream>
constexpr int N = 4;
// Declares the matrix creation function, which returns a pointer to a pointer
int **Creating_Matrix(int length);
// Declares the matrix destruction function
void Delete_Matrix(int **Matrix, int length);
// Declares the matrix addition function
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);
// Declares the matrix subtraction function
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);
// Declares Strassen's algorithm (the matrix multiplication function)
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);
int main(void)
{
    int **A, **B, **C;
    A = Creating_Matrix(N);
    B = Creating_Matrix(N);
    C = Creating_Matrix(N);
    // Fill A with 1 .. N*N and B with N*N + 1 .. 2*N*N, row by row
    int Number = 1;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            A[i][j] = Number;
            B[i][j] = Number + N * N;
            Number++;
        }
    }
    std::cout << "Matrix A: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << A[i][j] << "\t\t";
        std::cout << std::endl;
    }
    std::cout << std::endl;
    std::cout << "Matrix B: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << B[i][j] << "\t\t";
        std::cout << std::endl;
    }
    std::cout << std::endl;
    STRASSEN_ALGORITHM(A, B, C, N);
    std::cout << "Matrix C = A * B: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << C[i][j] << "\t\t";
        std::cout << std::endl;
    }
    Delete_Matrix(A, N);
    Delete_Matrix(B, N);
    Delete_Matrix(C, N);
    return 0;
}
// Defines the matrix creation function
int **Creating_Matrix(int length)
{
    int **Temporary = new int *[length];
    for (int i = 0; i < length; i++)
        Temporary[i] = new int[length];
    return Temporary;
}
// Defines the matrix destruction function
void Delete_Matrix(int **Matrix, int length)
{
    for (int i = 0; i < length; i++)
        delete[] Matrix[i];
    delete[] Matrix;
}
// Defines the matrix addition function
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] + Matrix_B[i][j];
}
// Defines the matrix subtraction function
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] - Matrix_B[i][j];
}
// Defines Strassen's algorithm (the matrix multiplication function)
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
if (length == 1)
Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];
else
{
int Middle = length / 2;
/*
 * When a built-in two-dimensional array is passed to a function, the size
 * of its second dimension must be fixed, so the length / 2 subproblems of
 * the recursion could not be passed as arguments.
 *
 * Instead we declare pointers to pointers. Each points to an array of
 * pointers whose elements in turn point to one-dimensional arrays, which
 * emulates a two-dimensional array.
 */
int **Matrix_A_11 = new int *[Middle];
int **Matrix_A_12 = new int *[Middle];
int **Matrix_A_21 = new int *[Middle];
int **Matrix_A_22 = new int *[Middle];
int **Matrix_B_11 = new int *[Middle];
int **Matrix_B_12 = new int *[Middle];
int **Matrix_B_21 = new int *[Middle];
int **Matrix_B_22 = new int *[Middle];
int **Matrix_C_11 = new int *[Middle];
int **Matrix_C_12 = new int *[Middle];
int **Matrix_C_21 = new int *[Middle];
int **Matrix_C_22 = new int *[Middle];
int **M1 = new int *[Middle];
int **M2 = new int *[Middle];
int **M3 = new int *[Middle];
int **M4 = new int *[Middle];
int **M5 = new int *[Middle];
int **M6 = new int *[Middle];
int **M7 = new int *[Middle];
int **Result_1 = new int *[Middle];
int **Result_2 = new int *[Middle];
for (int i = 0; i < Middle; i++)
{
Matrix_A_11[i] = new int [Middle];
Matrix_A_12[i] = new int [Middle];
Matrix_A_21[i] = new int [Middle];
Matrix_A_22[i] = new int [Middle];
Matrix_B_11[i] = new int [Middle];
Matrix_B_12[i] = new int [Middle];
Matrix_B_21[i] = new int [Middle];
Matrix_B_22[i] = new int [Middle];
Matrix_C_11[i] = new int [Middle];
Matrix_C_12[i] = new int [Middle];
Matrix_C_21[i] = new int [Middle];
Matrix_C_22[i] = new int [Middle];
M1[i] = new int [Middle];
M2[i] = new int [Middle];
M3[i] = new int [Middle];
M4[i] = new int [Middle];
M5[i] = new int [Middle];
M6[i] = new int [Middle];
M7[i] = new int [Middle];
Result_1[i] = new int [Middle];
Result_2[i] = new int [Middle];
}
/*
 * Now fill these "two-dimensional arrays" with their values:
 *
 * Matrix_A_11 takes the elements of Matrix_A at rows (0 ~ Middle - 1), columns (0 ~ Middle - 1);
 * Matrix_A_12 takes the elements of Matrix_A at rows (0 ~ Middle - 1), columns (Middle ~ length - 1);
 * Matrix_A_21 takes the elements of Matrix_A at rows (Middle ~ length - 1), columns (0 ~ Middle - 1);
 * Matrix_A_22 takes the elements of Matrix_A at rows (Middle ~ length - 1), columns (Middle ~ length - 1).
 *
 * The blocks of Matrix_B are filled in the same way.
 */
for (int i = 0; i < Middle; i++)
{
for (int j = 0; j < Middle; j++)
{
Matrix_A_11[i][j] = Matrix_A[i][j];
Matrix_A_12[i][j] = Matrix_A[i][j + Middle];
Matrix_A_21[i][j] = Matrix_A[i + Middle][j];
Matrix_A_22[i][j] = Matrix_A[i + Middle][j + Middle];
Matrix_B_11[i][j] = Matrix_B[i][j];
Matrix_B_12[i][j] = Matrix_B[i][j + Middle];
Matrix_B_21[i][j] = Matrix_B[i + Middle][j];
Matrix_B_22[i][j] = Matrix_B[i + Middle][j + Middle];
}
}
/*
 * Recursive matrix multiplications
 *
 * Before each recursive call, Add and Subtract compute the required sums
 * and differences, and Result_1 and Result_2 hold them temporarily.
 */
// M1
Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
// M2
Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
// M3
Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
// M4
Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);
STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
// M5
Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);
STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
// M6
Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);
Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
// M7
Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);
Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);
STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);
/*
 * Following Strassen's algorithm, combine the recursively computed
 * matrices M1...M7 by additions and subtractions to obtain Matrix_C_11,
 * Matrix_C_12, Matrix_C_21 and Matrix_C_22.
 *
 * Add and Subtract do the arithmetic, and Result_1 and Result_2 hold the
 * intermediate results.
 */
// Matrix_C_11
Add(M1, M4, Result_1, Middle);
Subtract(Result_1, M5, Result_2, Middle);
Add(Result_2, M7, Matrix_C_11, Middle);
// Matrix_C_12
Add(M3, M5, Matrix_C_12, Middle);
// Matrix_C_21
Add(M2, M4, Matrix_C_21, Middle);
// Matrix_C_22
Subtract(M1, M2, Result_1, Middle);
Add(Result_1, M3, Result_2, Middle);
Add(Result_2, M6, Matrix_C_22, Middle);
/*
 * Now assemble the four submatrices back into one large matrix
 */
for (int i = 0; i < Middle; i++)
{
for (int j = 0; j < Middle; j++)
{
Matrix_C[i][j] = Matrix_C_11[i][j];
Matrix_C[i][j + Middle] = Matrix_C_12[i][j];
Matrix_C[i + Middle][j] = Matrix_C_21[i][j];
Matrix_C[i + Middle][j + Middle] = Matrix_C_22[i][j];
}
}
/*
 * Finally, release the dynamically allocated memory
 */
for (int i = 0; i < Middle; i++)
{
delete[] Matrix_A_11[i];
delete[] Matrix_A_12[i];
delete[] Matrix_A_21[i];
delete[] Matrix_A_22[i];
delete[] Matrix_B_11[i];
delete[] Matrix_B_12[i];
delete[] Matrix_B_21[i];
delete[] Matrix_B_22[i];
delete[] Matrix_C_11[i];
delete[] Matrix_C_12[i];
delete[] Matrix_C_21[i];
delete[] Matrix_C_22[i];
delete[] M1[i];
delete[] M2[i];
delete[] M3[i];
delete[] M4[i];
delete[] M5[i];
delete[] M6[i];
delete[] M7[i];
delete[] Result_1[i];
delete[] Result_2[i];
}
delete[] Matrix_A_11;
delete[] Matrix_A_12;
delete[] Matrix_A_21;
delete[] Matrix_A_22;
delete[] Matrix_B_11;
delete[] Matrix_B_12;
delete[] Matrix_B_21;
delete[] Matrix_B_22;
delete[] Matrix_C_11;
delete[] Matrix_C_12;
delete[] Matrix_C_21;
delete[] Matrix_C_22;
delete[] M1;
delete[] M2;
delete[] M3;
delete[] M4;
delete[] M5;
delete[] M6;
delete[] M7;
delete[] Result_1;
delete[] Result_2;
}
}
/*
 * Output:
 *
 * Matrix A:
 * 1 2 3 4
 * 5 6 7 8
 * 9 10 11 12
 * 13 14 15 16
 *
 * Matrix B:
 * 17 18 19 20
 * 21 22 23 24
 * 25 26 27 28
 * 29 30 31 32
 *
 * Matrix C = A * B:
 * 250 260 270 280
 * 618 644 670 696
 * 986 1028 1070 1112
 * 1354 1412 1470 1528
 */
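One way to double-check the numbers in that output is to compare against the ordinary triple-loop product. The sketch below is not part of the original test program; the `Matrix` alias and the helper names are assumptions for this example, and it uses `std::vector` instead of `int **` just to stay short.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical alias for this sketch; the test program above uses int**.
using Matrix = std::vector<std::vector<int>>;

// Standard triple-loop product of two n x n matrices, Theta(n^3).
Matrix MultiplyNaive(const Matrix &a, const Matrix &b)
{
    std::size_t n = a.size();
    assert(b.size() == n);
    Matrix c(n, std::vector<int>(n, 0));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t k = 0; k < n; k++)
            for (std::size_t j = 0; j < n; j++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

// Builds the n x n matrix whose entries count up row by row from `first`,
// e.g. MakeSequential(4, 1) holds 1..16 and MakeSequential(4, 17) holds 17..32.
Matrix MakeSequential(std::size_t n, int first)
{
    Matrix m(n, std::vector<int>(n, 0));
    int value = first;
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < n; j++)
            m[i][j] = value++;
    return m;
}
```

Calling `MultiplyNaive(MakeSequential(4, 1), MakeSequential(4, 17))` gives 250 in the top-left corner and 1528 in the bottom-right, matching the Strassen output listed above.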