–《算法导论第三版》第4章 分治策略
Strassen算法的核心思想是令递归树稍微不那么茂盛一点儿,即只递归进行7次而不是8次 $\frac{n}{2} \times \frac{n}{2}$ 矩阵的乘法。减少一次矩阵乘法带来的代价可能是额外几次 $\frac{n}{2} \times \frac{n}{2}$ 矩阵的加法,但只是常数次。
假定将A,B和C均分解为4个 $\frac{n}{2} \times \frac{n}{2}$ 的子矩阵:
$$
A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix},\quad
B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix},\quad
C = \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
$$
因此可将公式 $C = A \cdot B$ 改写成
$$
\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
= \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}
\cdot \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
$$
则:
$$
\begin{aligned}
C_{11} &= A_{11} \cdot B_{11} + A_{12} \cdot B_{21}\\
C_{12} &= A_{11} \cdot B_{12} + A_{12} \cdot B_{22}\\
C_{21} &= A_{21} \cdot B_{11} + A_{22} \cdot B_{21}\\
C_{22} &= A_{21} \cdot B_{12} + A_{22} \cdot B_{22}
\end{aligned}
$$
Strassen算法包含四步
-
按上述方法将A,B,C分解(花费时间 $\Theta(1)$)。
-
如下创建10个 $\frac{n}{2} \times \frac{n}{2}$ 的矩阵 $S_1, S_2, \ldots, S_{10}$(花费时间 $\Theta(n^2)$)。
$$
\begin{aligned}
S_1 &= B_{12} - B_{22}\\
S_2 &= A_{11} + A_{12}\\
S_3 &= A_{21} + A_{22}\\
S_4 &= B_{21} - B_{11}\\
S_5 &= A_{11} + A_{22}\\
S_6 &= B_{11} + B_{22}\\
S_7 &= A_{12} - A_{22}\\
S_8 &= B_{21} + B_{22}\\
S_9 &= A_{11} - A_{21}\\
S_{10} &= B_{11} + B_{12}
\end{aligned}
$$
-
如下递归地计算7个矩阵积 $P_1, P_2, \ldots, P_7$。每个矩阵 $P_i$ 都是 $\frac{n}{2} \times \frac{n}{2}$ 的。
$$
\begin{aligned}
P_1 &= A_{11} \cdot S_1 = A_{11} \cdot B_{12} - A_{11} \cdot B_{22}\\
P_2 &= S_2 \cdot B_{22} = A_{11} \cdot B_{22} + A_{12} \cdot B_{22}\\
P_3 &= S_3 \cdot B_{11} = A_{21} \cdot B_{11} + A_{22} \cdot B_{11}\\
P_4 &= A_{22} \cdot S_4 = A_{22} \cdot B_{21} - A_{22} \cdot B_{11}\\
P_5 &= S_5 \cdot S_6 = A_{11} \cdot B_{11} + A_{11} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22}\\
P_6 &= S_7 \cdot S_8 = A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22}\\
P_7 &= S_9 \cdot S_{10} = A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12}
\end{aligned}
$$
注意,上述公式中只有中间一列的乘积需要实际计算;右边一列只是把这些乘积按原始子矩阵展开后的结果,用于说明。
- 通过 $P_i$ 计算 $C_{11}, C_{12}, C_{21}, C_{22}$(花费时间 $\Theta(n^2)$)。
$$
\begin{aligned}
C_{11} &= P_5 + P_4 - P_2 + P_6\\
C_{12} &= P_1 + P_2\\
C_{21} &= P_3 + P_4\\
C_{22} &= P_5 + P_1 - P_3 - P_7
\end{aligned}
$$
可得如下递归式:
$$
T(n) = \begin{cases}
\Theta(1) & \text{若 } n = 1\\
7T\left(\frac{n}{2}\right) + \Theta(n^2) & \text{若 } n > 1
\end{cases}
$$
由递归树或主方法可得:
$$
T(n) = \Theta(n^{\lg 7})
$$
C++实现如下:
#include "stdafx.h"
#include <stdio.h>
#include <iostream>
#include <windows.h>
#include <ctime>
using namespace std;
// Collection of square-matrix helpers operating on matrices stored as an
// array of row pointers (T**). Every method takes the side length `size`
// explicitly; none of them allocates or frees the matrices passed in.
template <typename T>
class Strassen
{
public:
    // Assigns initial values to both input matrices A and B.
    void FillMatrix(T** MatrixA, T** MatrixB, int size);

    // Element-wise MatrixResult = MatrixA + MatrixB.
    void ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int size);

    // Element-wise MatrixResult = MatrixA - MatrixB.
    void SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int size);

    // MatrixResult = MatrixA * MatrixB using the plain triple loop.
    void NormalMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size);

    // MatrixResult = MatrixA * MatrixB using Strassen's recursive scheme.
    void StrassenMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size);

    // Sum of all elements of the matrix. Used as a checksum: when both
    // multiplication algorithms yield the same sum, the result is taken
    // to be correct.
    int GetMatrixSum(T** Matrix, int size);
};
// Element-wise sum: MatrixResult[r][c] = MatrixA[r][c] + MatrixB[r][c] for
// every cell of the size x size matrices. Each output cell depends only on
// the same cell of the inputs, so MatrixResult may be one of the inputs.
template <typename T>
void Strassen<T>::ADD(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
    for(int row = 0; row < size; ++row)
    {
        T * lhs = MatrixA[row];
        T * rhs = MatrixB[row];
        T * out = MatrixResult[row];
        for(int col = 0; col < size; ++col)
        {
            out[col] = lhs[col] + rhs[col];
        }
    }
}
// Element-wise difference: MatrixResult[r][c] = MatrixA[r][c] - MatrixB[r][c]
// for every cell of the size x size matrices. Each output cell depends only
// on the same cell of the inputs, so MatrixResult may be one of the inputs.
template <typename T>
void Strassen<T>::SUB(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
    for(int row = 0; row < size; ++row)
    {
        T * lhs = MatrixA[row];
        T * rhs = MatrixB[row];
        T * out = MatrixResult[row];
        for(int col = 0; col < size; ++col)
        {
            out[col] = lhs[col] - rhs[col];
        }
    }
}
// Straightforward cubic-time product: MatrixResult = MatrixA * MatrixB for
// size x size matrices. Each result cell is zeroed and then accumulated in
// place, so the inputs are still being read while the result is written.
template <typename T>
void Strassen<T>::NormalMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
    for(int row = 0; row < size; ++row)
    {
        T * out = MatrixResult[row];
        for(int col = 0; col < size; ++col)
        {
            out[col] = 0;
            // Dot product of row `row` of A with column `col` of B.
            for(int inner = 0; inner < size; ++inner)
            {
                out[col] += MatrixA[row][inner] * MatrixB[inner][col];
            }
        }
    }
}
// Gives A and B their initial values: one rand() % 5 draw per cell, visited
// in row-major order, stored into B and then copied into the same cell of A
// (so both matrices end up with identical contents).
template <typename T>
void Strassen<T>::FillMatrix(T ** MatrixA, T ** MatrixB, int size)
{
    for(int row = 0; row < size; ++row)
    {
        T * rowA = MatrixA[row];
        T * rowB = MatrixB[row];
        for(int col = 0; col < size; ++col)
        {
            rowB[col] = rand() % 5;
            rowA[col] = rowB[col];
        }
    }
}
// Private helper for Strassen<T>::StrassenMul: owns one n x n scratch matrix
// (a row-pointer table over a single contiguous cell buffer) and frees it in
// the destructor, so the storage is released on every exit path, including
// when a later allocation throws. Converts implicitly to T** so it can be
// passed to ADD / SUB / StrassenMul and indexed as m[i][j].
template <typename T>
class StrassenScratch
{
public:
    explicit StrassenScratch(int n) : rows_(new T * [n]), cells_(0)
    {
        try
        {
            cells_ = new T[static_cast<size_t>(n) * n];
        }
        catch(...)
        {
            // The destructor does not run for a partially built object.
            delete[] rows_;
            throw;
        }
        for(int i = 0; i < n; i++)
        {
            rows_[i] = cells_ + static_cast<size_t>(i) * n;
        }
    }
    ~StrassenScratch()
    {
        delete[] cells_;
        delete[] rows_;
    }
    operator T **() const { return rows_; }
private:
    // Non-copyable: a copy would free the same buffers twice.
    StrassenScratch(const StrassenScratch &);
    StrassenScratch & operator=(const StrassenScratch &);
    T ** rows_;
    T * cells_;
};

// Computes MatrixResult = MatrixA * MatrixB for size x size matrices using
// Strassen's scheme: 7 recursive half-size products instead of 8.
//
// Parameters:
//   MatrixA, MatrixB - inputs, each an array of `size` row pointers.
//   MatrixResult     - output, same layout; every cell is overwritten.
//   size             - side length. A non-positive size is a no-op.
//
// The recursion halves the matrix while the current size is even. As soon
// as an odd size greater than 1 is reached, that sub-problem is handed to
// NormalMul, so the result is correct for any positive size, not only for
// powers of two.
//
// MatrixResult should not alias MatrixA or MatrixB: the odd-size fallback
// reads the inputs while it writes the result.
//
// The seven products below use the M1..M7 naming rather than CLRS's P1..P7:
// M1 = P5, M2 = P3, M3 = P1, M4 = P4, M5 = P2, M6 = -P7, M7 = P6.
template <typename T>
void Strassen<T>::StrassenMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
    // Optional recursion cutoff, left disabled: below this size, stop
    // recursing and use the ordinary multiplication instead.
    // if ( size <= 64 )
    // {
    //     NormalMul(MatrixA, MatrixB, MatrixResult, size);
    //     return;
    // }
    if(size <= 0)
    {
        return;
    }
    if(size == 1)
    {
        MatrixResult[0][0] = MatrixA[0][0] * MatrixB[0][0];
        return;
    }
    if(size % 2 != 0)
    {
        // An odd matrix cannot be split into four equal quadrants.
        NormalMul(MatrixA, MatrixB, MatrixResult, size);
        return;
    }

    const int half_size = size / 2;

    // Scratch storage is typed on T (not int) so the template works for
    // every element type.
    StrassenScratch<T> A11(half_size), A12(half_size), A21(half_size), A22(half_size);
    StrassenScratch<T> B11(half_size), B12(half_size), B21(half_size), B22(half_size);
    StrassenScratch<T> C11(half_size), C12(half_size), C21(half_size), C22(half_size);
    StrassenScratch<T> M1(half_size), M2(half_size), M3(half_size), M4(half_size);
    StrassenScratch<T> M5(half_size), M6(half_size), M7(half_size);
    StrassenScratch<T> MatrixTemp1(half_size), MatrixTemp2(half_size);

    // Copy the four quadrants of A and B into their own matrices.
    for(int i = 0; i < half_size; i++)
    {
        for(int j = 0; j < half_size; j++)
        {
            A11[i][j] = MatrixA[i][j];
            A12[i][j] = MatrixA[i][j + half_size];
            A21[i][j] = MatrixA[i + half_size][j];
            A22[i][j] = MatrixA[i + half_size][j + half_size];
            B11[i][j] = MatrixB[i][j];
            B12[i][j] = MatrixB[i][j + half_size];
            B21[i][j] = MatrixB[i + half_size][j];
            B22[i][j] = MatrixB[i + half_size][j + half_size];
        }
    }

    // M1 = (A11 + A22) * (B11 + B22)
    ADD(A11, A22, MatrixTemp1, half_size);
    ADD(B11, B22, MatrixTemp2, half_size);
    StrassenMul(MatrixTemp1, MatrixTemp2, M1, half_size);
    // M2 = (A21 + A22) * B11
    ADD(A21, A22, MatrixTemp1, half_size);
    StrassenMul(MatrixTemp1, B11, M2, half_size);
    // M3 = A11 * (B12 - B22)
    SUB(B12, B22, MatrixTemp1, half_size);
    StrassenMul(A11, MatrixTemp1, M3, half_size);
    // M4 = A22 * (B21 - B11)
    SUB(B21, B11, MatrixTemp1, half_size);
    StrassenMul(A22, MatrixTemp1, M4, half_size);
    // M5 = (A11 + A12) * B22
    ADD(A11, A12, MatrixTemp1, half_size);
    StrassenMul(MatrixTemp1, B22, M5, half_size);
    // M6 = (A21 - A11) * (B11 + B12)
    SUB(A21, A11, MatrixTemp1, half_size);
    ADD(B11, B12, MatrixTemp2, half_size);
    StrassenMul(MatrixTemp1, MatrixTemp2, M6, half_size);
    // M7 = (A12 - A22) * (B21 + B22)
    SUB(A12, A22, MatrixTemp1, half_size);
    ADD(B21, B22, MatrixTemp2, half_size);
    StrassenMul(MatrixTemp1, MatrixTemp2, M7, half_size);

    // C11 = M1 + M4 - M5 + M7
    ADD(M1, M4, C11, half_size);
    SUB(C11, M5, C11, half_size);
    ADD(C11, M7, C11, half_size);
    // C12 = M3 + M5
    ADD(M3, M5, C12, half_size);
    // C21 = M2 + M4
    ADD(M2, M4, C21, half_size);
    // C22 = M1 - M2 + M3 + M6
    SUB(M1, M2, C22, half_size);
    ADD(C22, M3, C22, half_size);
    ADD(C22, M6, C22, half_size);

    // Assemble the four result quadrants into MatrixResult.
    for(int i = 0; i < half_size; i++)
    {
        for(int j = 0; j < half_size; j++)
        {
            MatrixResult[i][j] = C11[i][j];
            MatrixResult[i][j + half_size] = C12[i][j];
            MatrixResult[i + half_size][j] = C21[i][j];
            MatrixResult[i + half_size][j + half_size] = C22[i][j];
        }
    }
    // All scratch matrices are released by their destructors here.
}
// Adds up every element of the size x size matrix into an int accumulator
// and returns it. The caller uses the value as a checksum for comparing the
// outputs of the two multiplication routines.
template <typename T>
int Strassen<T>::GetMatrixSum(T ** Matrix, int size)
{
    int total = 0;
    for(int row = 0; row < size; ++row)
    {
        T * cells = Matrix[row];
        for(int col = 0; col < size; ++col)
        {
            total += cells[col];
        }
    }
    return total;
}
int main()
{
long startTime_normal, endTime_normal;
long startTime_strasse, endTime_strassen;
//srand(time(0));
Strassen<int> stra;
int N;
cout<<"please input the size of Matrix,and the size must be the power of 2:"<<endl;
cin>>N;
int ** Matrix1 = new int * [N];
int ** Matrix2 = new int * [N];
int ** Matrix3 = new int * [N];
for(int i=0;i<N;i++)
{
Matrix1[i] = new int[N];
Matrix2[i] = new int[N];
Matrix3[i] = new int[N];
}
stra.FillMatrix(Matrix1, Matrix2,N);
cout << "朴素算法开始时间:" << (startTime_normal = clock()) << endl;
stra.NormalMul(Matrix1, Matrix2, Matrix3,N);
cout << "朴素算法结束时间:" << (endTime_normal = clock()) << endl;
cout << "总时间:" << endTime_normal-startTime_normal << endl;
cout << "sum = " << stra.GetMatrixSum(Matrix3,N) << ';' << endl;
cout << "Strassen算法开始时间:" << (startTime_strasse= clock()) << endl;
stra.StrassenMul(Matrix1,Matrix2,Matrix3,N);
cout << "Strassen算法结束时间:" << (endTime_strassen = clock()) << endl;
cout << "总时间:" << endTime_strassen-startTime_strasse << endl;
cout << "sum = " << stra.GetMatrixSum(Matrix3,N) << ';' << endl;
}