Strassen's Algorithm

Time complexity: $\Theta\left(n^{\log_{2}7}\right) \approx \Theta\left(n^{2.8074}\right)$


The main idea
The core of Strassen's algorithm is to make the recursion tree a little less bushy: at each level it recurses on only $7$ products of $n/2 \times n/2$ matrices instead of $8$.

Let $A$ and $B$ be two square matrices over a ring $R$. We want to compute the matrix product $C$ as:

   $C = AB, \qquad A, B, C \in R^{2^{n} \times 2^{n}}$

If $A$ and $B$ are not of size $2^{n} \times 2^{n}$, we fill in the missing rows and columns with $0$.

We partition $A$, $B$, and $C$ into equally sized block matrices:

   $A=\begin{bmatrix}A_{1,1} & A_{1,2} \\A_{2,1} & A_{2,2}\end{bmatrix}$   $B=\begin{bmatrix}B_{1,1} & B_{1,2} \\B_{2,1} & B_{2,2}\end{bmatrix}$   $C=\begin{bmatrix}C_{1,1} & C_{1,2} \\C_{2,1} & C_{2,2}\end{bmatrix}$

The naive algorithm then gives:

   $C_{1,1}=A_{1,1} \times B_{1,1}+A_{1,2} \times B_{2,1}$
   $C_{1,2}=A_{1,1} \times B_{1,2}+A_{1,2} \times B_{2,2}$
   $C_{2,1}=A_{2,1} \times B_{1,1}+A_{2,2} \times B_{2,1}$
   $C_{2,2}=A_{2,1} \times B_{1,2}+A_{2,2} \times B_{2,2}$

However, this construction by itself does not reduce the number of multiplications. We still need $8$ block multiplications to compute the $C_{i,j}$ matrices, the same number as standard matrix multiplication.

Strassen's algorithm defines the following new matrices:

   $M_{1}=\left(A_{1,1}+A_{2,2}\right)\left(B_{1,1}+B_{2,2}\right)$
   $M_{2}=\left(A_{2,1}+A_{2,2}\right)B_{1,1}$
   $M_{3}=A_{1,1}\left(B_{1,2}-B_{2,2}\right)$
   $M_{4}=A_{2,2}\left(B_{2,1}-B_{1,1}\right)$
   $M_{5}=\left(A_{1,1}+A_{1,2}\right)B_{2,2}$
   $M_{6}=\left(A_{2,1}-A_{1,1}\right)\left(B_{1,1}+B_{1,2}\right)$
   $M_{7}=\left(A_{1,2}-A_{2,2}\right)\left(B_{2,1}+B_{2,2}\right)$

This uses only $7$ multiplications (one for each $M_{k}$) instead of $8$. We can now express each $C_{i,j}$ as sums and differences of the $M_{k}$:

   $C_{1,1}=M_{1}+M_{4}-M_{5}+M_{7}$
   $C_{1,2}=M_{3}+M_{5}$
   $C_{2,1}=M_{2}+M_{4}$
   $C_{2,2}=M_{1}-M_{2}+M_{3}+M_{6}$

   (from Wikipedia)

Let's verify one of them:

   $C_{1,2}=M_{3}+M_{5}$
          $=A_{1,1}\left(B_{1,2}-B_{2,2}\right)+\left(A_{1,1}+A_{1,2}\right)B_{2,2}$
          $=A_{1,1} \times B_{1,2}-A_{1,1} \times B_{2,2}+A_{1,1} \times B_{2,2}+A_{1,2} \times B_{2,2}$
          $=A_{1,1} \times B_{1,2}+A_{1,2} \times B_{2,2}$


Pseudocode

// The pseudocode below is adapted from someone else's; I was honestly too lazy to write my own...

STRASSEN(A, B)
    Length = number of rows (or columns) of A and B
    let C be a new (Length * Length) matrix
    if Length == 1
        C = A * B
    else partition A, B, and C into four (Length/2 * Length/2) blocks each
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        P1 = STRASSEN(A11, S1)
        P2 = STRASSEN(S2, B22)
        P3 = STRASSEN(S3, B11)
        P4 = STRASSEN(A22, S4)
        P5 = STRASSEN(S5, S6)
        P6 = STRASSEN(S7, S8)
        P7 = STRASSEN(S9, S10)
        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7
    return C

C++ code


// The following code is adapted from https://blog.csdn.net/zhuangxiaobin/article/details/36476769

/* ------------------------------------------------------------
 *
 * Matrix addition; time complexity: O(n^2)
 *
 * ------------------------------------------------------------ */
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)        
        for (int j = 0; j < length; j++)            
            Matrix_C[i][j] = Matrix_A[i][j] + Matrix_B[i][j];
}

/* ------------------------------------------------------------
 *
 * Matrix subtraction; time complexity: O(n^2)
 *
 * ------------------------------------------------------------ */
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)        
        for (int j = 0; j < length; j++)            
            Matrix_C[i][j] = Matrix_A[i][j] - Matrix_B[i][j];
}

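/* ------------------------------------------------------------
 * Strassen's algorithm (matrix multiplication)
 *
 * Parameters:
 *   Matrix_A is a pointer to pointer: it should point to an array of row
 *   pointers, each pointing to an ordinary int array, which together form
 *   a 2-D array.
 *   Matrix_B and Matrix_C are laid out the same way.
 *   The 4th parameter (length) is the number of rows (or columns).
 *   Note: the matrices must be square and length must be a power of two.
 *   This function does not pad its inputs, so otherwise the caller has to
 *   zero-pad them to length * length first (starting from the longer side).
 * ------------------------------------------------------------ */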
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    if (length == 1)        
        Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];    
    else    
    {        
        int Middle = length / 2;
        
        
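        /* ------------------------------------------------------------
         *
         * A built-in 2-D array can only be passed to a function if its
         * second dimension is fixed, so we could not pass sub-matrices
         * of size length / 2 down the recursion that way.
         *
         * Instead we declare pointers to pointers. Each one points to an
         * array of pointers, and each of those pointers points to a 1-D
         * array, which together simulate a 2-D array.
         *
         * ------------------------------------------------------------ */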
         
        int **Matrix_A_11 = new int *[Middle];        
        int **Matrix_A_12 = new int *[Middle];        
        int **Matrix_A_21 = new int *[Middle];        
        int **Matrix_A_22 = new int *[Middle];
        
        int **Matrix_B_11 = new int *[Middle];        
        int **Matrix_B_12 = new int *[Middle];        
        int **Matrix_B_21 = new int *[Middle];        
        int **Matrix_B_22 = new int *[Middle];
        
        int **Matrix_C_11 = new int *[Middle];        
        int **Matrix_C_12 = new int *[Middle];        
        int **Matrix_C_21 = new int *[Middle];        
        int **Matrix_C_22 = new int *[Middle];
        
        int **M1 = new int *[Middle];        
        int **M2 = new int *[Middle];        
        int **M3 = new int *[Middle];        
        int **M4 = new int *[Middle];        
        int **M5 = new int *[Middle];        
        int **M6 = new int *[Middle];        
        int **M7 = new int *[Middle];
        
        int **Result_1 = new int *[Middle];        
        int **Result_2 = new int *[Middle];
        
        for (int i = 0; i < Middle; i++)        
        {
            Matrix_A_11[i] = new int [Middle];            
            Matrix_A_12[i] = new int [Middle];            
            Matrix_A_21[i] = new int [Middle];            
            Matrix_A_22[i] = new int [Middle];
            
            Matrix_B_11[i] = new int [Middle];            
            Matrix_B_12[i] = new int [Middle];            
            Matrix_B_21[i] = new int [Middle];            
            Matrix_B_22[i] = new int [Middle];
            
            Matrix_C_11[i] = new int [Middle];            
            Matrix_C_12[i] = new int [Middle];            
            Matrix_C_21[i] = new int [Middle];            
            Matrix_C_22[i] = new int [Middle];
            
            M1[i] = new int [Middle];            
            M2[i] = new int [Middle];            
            M3[i] = new int [Middle];            
            M4[i] = new int [Middle];            
            M5[i] = new int [Middle];            
            M6[i] = new int [Middle];            
            M7[i] = new int [Middle];
            
            Result_1[i] = new int [Middle];            
            Result_2[i] = new int [Middle];
        }
        
        
        /* ------------------------------------------------------------
         *
         * Now we fill these "2-D arrays" with the appropriate values:
         *
         * Matrix_A_11 takes the elements of Matrix_A with indices
         * (0 ~ Middle - 1) * (0 ~ Middle - 1);
         *
         * Matrix_A_12 takes the elements of Matrix_A with indices
         * (0 ~ Middle - 1) * (Middle ~ length - 1);
         *
         * Matrix_A_21 takes the elements of Matrix_A with indices
         * (Middle ~ length - 1) * (0 ~ Middle - 1);
         *
         * Matrix_A_22 takes the elements of Matrix_A with indices
         * (Middle ~ length - 1) * (Middle ~ length - 1).
         *
         * The blocks of Matrix_B are filled the same way.
         *
         * ------------------------------------------------------------ */
        
        for (int i = 0; i < Middle; i++)        
        {            
            for (int j = 0; j < Middle; j++)            
            {                
                Matrix_A_11[i][j] = Matrix_A[i][j];                
                Matrix_A_12[i][j] = Matrix_A[i][j + Middle];                
                Matrix_A_21[i][j] = Matrix_A[i + Middle][j];                
                Matrix_A_22[i][j] = Matrix_A[i + Middle][j + Middle];
                Matrix_B_11[i][j] = Matrix_B[i][j];                
                Matrix_B_12[i][j] = Matrix_B[i][j + Middle];                
                Matrix_B_21[i][j] = Matrix_B[i + Middle][j];                
                Matrix_B_22[i][j] = Matrix_B[i + Middle][j + Middle];            
             }        
         }


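        /* ------------------------------------------------------------
         *
         * Recursive matrix multiplications
         *
         * Before each recursive call we use Add and Subtract to perform
         * the required additions and subtractions, storing the
         * intermediate results in Result_1 and Result_2.
         *
         * ------------------------------------------------------------ */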
        
        // M1        
        Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);        
        Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
        
        // M2        
        Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
        
        // M3        
        Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
        
        // M4        
        Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
        
        // M5        
        Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
        
        // M6        
        Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);        
        Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
        
        // M7        
        Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);        
        Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);


        /* ------------------------------------------------------------
         *
         * Following the steps of Strassen's algorithm, we now combine the
         * recursively computed matrices M1...M7 with the required
         * additions and subtractions to obtain Matrix_C_11, Matrix_C_12,
         * Matrix_C_21, and Matrix_C_22.
         *
         * We call Add and Subtract for these operations and use Result_1
         * and Result_2 to hold the intermediate results.
         *
         * ------------------------------------------------------------ */

        // Matrix_C_11
        Add(M1, M4, Result_1, Middle);        
        Subtract(Result_1, M5, Result_2, Middle);        
        Add(Result_2, M7, Matrix_C_11, Middle);
        
        // Matrix_C_12        
        Add(M3, M5, Matrix_C_12, Middle);
        
        // Matrix_C_21        
        Add(M2, M4, Matrix_C_21, Middle);
        
        // Matrix_C_22        
        Subtract(M1, M2, Result_1, Middle);        
        Add(Result_1, M3, Result_2, Middle);        
        Add(Result_2, M6, Matrix_C_22, Middle);


        /* ------------------------------------------------------------
         *
         * Now we stitch the four sub-matrices back into one large matrix
         *
         * ------------------------------------------------------------ */

        for (int i = 0; i < Middle; i++)
        {
            for (int j = 0; j < Middle; j++)
            {
                Matrix_C[i][j] = Matrix_C_11[i][j];
                Matrix_C[i][j + Middle] = Matrix_C_12[i][j];
                Matrix_C[i + Middle][j] = Matrix_C_21[i][j];
                Matrix_C[i + Middle][j + Middle] = Matrix_C_22[i][j];
            }
        }
	
	
        /* ------------------------------------------------------------
         *
         * Finally, free the dynamically allocated memory
         *
         * ------------------------------------------------------------ */
        for (int i = 0; i < Middle; i++)
        {
            delete[] Matrix_A_11[i];
            delete[] Matrix_A_12[i];
            delete[] Matrix_A_21[i];
            delete[] Matrix_A_22[i];
	    
            delete[] Matrix_B_11[i];            
            delete[] Matrix_B_12[i];            
            delete[] Matrix_B_21[i];            
            delete[] Matrix_B_22[i];
            
            delete[] Matrix_C_11[i];            
            delete[] Matrix_C_12[i];            
            delete[] Matrix_C_21[i];            
            delete[] Matrix_C_22[i];
            
            delete[] M1[i];            
            delete[] M2[i];            
            delete[] M3[i];            
            delete[] M4[i];            
            delete[] M5[i];            
            delete[] M6[i];            
            delete[] M7[i];
            
            delete[] Result_1[i];            
            delete[] Result_2[i];
                     
        }
        delete[] Matrix_A_11;
        delete[] Matrix_A_12;
        delete[] Matrix_A_21;
        delete[] Matrix_A_22;

        delete[] Matrix_B_11;
        delete[] Matrix_B_12;
        delete[] Matrix_B_21;
        delete[] Matrix_B_22;
	
        delete[] Matrix_C_11;            
        delete[] Matrix_C_12;            
        delete[] Matrix_C_21;            
        delete[] Matrix_C_22;
        
        delete[] M1;            
        delete[] M2;            
        delete[] M3;            
        delete[] M4;            
        delete[] M5;            
        delete[] M6;            
        delete[] M7;
        
        delete[] Result_1;            
        delete[] Result_2;
        
    }    
}

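My slightly simplified version

/* ------------------------------------------------------------
 * A slightly improved version of the Strassen code above:
 *
 * it tries to reduce the space used by the previous version.
 *
 * It relies on array indexing and pointer arithmetic: the 12
 * dynamically allocated (n / 2) * (n / 2) sub-matrices of A, B, and C
 * are replaced by 12 arrays of n / 2 row pointers that point straight
 * into the original matrices.
 *
 * ------------------------------------------------------------ */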

void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    if (length == 1)        
        Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];    
    else    
    {        
        int Middle = length / 2;
        
        int **Matrix_A_11 = new int *[Middle];        
        int **Matrix_A_12 = new int *[Middle];        
        int **Matrix_A_21 = new int *[Middle];        
        int **Matrix_A_22 = new int *[Middle];
        
        int **Matrix_B_11 = new int *[Middle];        
        int **Matrix_B_12 = new int *[Middle];        
        int **Matrix_B_21 = new int *[Middle];        
        int **Matrix_B_22 = new int *[Middle];
        
        int **Matrix_C_11 = new int *[Middle];        
        int **Matrix_C_12 = new int *[Middle];        
        int **Matrix_C_21 = new int *[Middle];        
        int **Matrix_C_22 = new int *[Middle];
        
        int **M1 = new int *[Middle];        
        int **M2 = new int *[Middle];        
        int **M3 = new int *[Middle];        
        int **M4 = new int *[Middle];        
        int **M5 = new int *[Middle];        
        int **M6 = new int *[Middle];        
        int **M7 = new int *[Middle];
        
        int **Result_1 = new int *[Middle];        
        int **Result_2 = new int *[Middle];


        for (int i = 0; i < Middle; i++)        
        {                        
            M1[i] = new int [Middle];            
            M2[i] = new int [Middle];            
            M3[i] = new int [Middle];            
            M4[i] = new int [Middle];            
            M5[i] = new int [Middle];            
            M6[i] = new int [Middle];            
            M7[i] = new int [Middle];
            
            Result_1[i] = new int [Middle];            
            Result_2[i] = new int [Middle];
        }


        for (int i = 0; i < Middle; i++)        
        {            
            Matrix_A_11[i] = Matrix_A[i];            
            Matrix_A_12[i] = Matrix_A[i] + Middle;            
            Matrix_A_21[i] = Matrix_A[i + Middle];            
            Matrix_A_22[i] = Matrix_A[i + Middle] + Middle;
            
            Matrix_B_11[i] = Matrix_B[i];            
            Matrix_B_12[i] = Matrix_B[i] + Middle;            
            Matrix_B_21[i] = Matrix_B[i + Middle];            
            Matrix_B_22[i] = Matrix_B[i + Middle] + Middle;
            
            Matrix_C_11[i] = Matrix_C[i];            
            Matrix_C_12[i] = Matrix_C[i] + Middle;            
            Matrix_C_21[i] = Matrix_C[i + Middle];            
            Matrix_C_22[i] = Matrix_C[i + Middle] + Middle;        
        }


        /* ------------------------------------------------------------
         *
         * Recursive matrix multiplications
         *
         * Before each recursive call we use Add and Subtract to perform
         * the required additions and subtractions, storing the
         * intermediate results in Result_1 and Result_2.
         *
         * ------------------------------------------------------------ */

        // M1
        Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);
        Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
        
        // M2        
        Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
        
        // M3        
        Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
        
        // M4        
        Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
        
        // M5        
        Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
        
        // M6        
        Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);        
        Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
        
        // M7        
        Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);        
        Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);


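        /* ------------------------------------------------------------
         *
         * Following the steps of Strassen's algorithm, we now combine the
         * recursively computed matrices M1...M7 with the required
         * additions and subtractions to obtain Matrix_C_11, Matrix_C_12,
         * Matrix_C_21, and Matrix_C_22. Because these blocks point into
         * Matrix_C, the results are written directly into it.
         *
         * We call Add and Subtract for these operations and use Result_1
         * and Result_2 to hold the intermediate results.
         *
         * ------------------------------------------------------------ */

        // Matrix_C_11
        Add(M1, M4, Result_1, Middle);
        Subtract(Result_1, M5, Result_2, Middle);
        Add(Result_2, M7, Matrix_C_11, Middle);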
	
        // Matrix_C_12        
        Add(M3, M5, Matrix_C_12, Middle);
        
        // Matrix_C_21        
        Add(M2, M4, Matrix_C_21, Middle);
        
        // Matrix_C_22        
        Subtract(M1, M2, Result_1, Middle);        
        Add(Result_1, M3, Result_2, Middle);        
        Add(Result_2, M6, Matrix_C_22, Middle);

        /* ------------------------------------------------------------
         *
         * Finally, free the dynamically allocated memory
         * (only M1...M7 and Result_1/Result_2 own their rows here)
         *
         * ------------------------------------------------------------ */
        for (int i = 0; i < Middle; i++)        
        {
            delete[] M1[i];            
            delete[] M2[i];            
            delete[] M3[i];            
            delete[] M4[i];            
            delete[] M5[i];            
            delete[] M6[i];            
            delete[] M7[i];
            
            delete[] Result_1[i];            
            delete[] Result_2[i];
        }
        delete[] Matrix_A_11;        
        delete[] Matrix_A_12;        
        delete[] Matrix_A_21;        
        delete[] Matrix_A_22;
        
        delete[] Matrix_B_11;        
        delete[] Matrix_B_12;        
        delete[] Matrix_B_21;        
        delete[] Matrix_B_22;
        
        delete[] Matrix_C_11;        
        delete[] Matrix_C_12;        
        delete[] Matrix_C_21;        
        delete[] Matrix_C_22;
        
        delete[] M1;        
        delete[] M2;        
        delete[] M3;        
        delete[] M4;        
        delete[] M5;        
        delete[] M6;        
        delete[] M7;
        
        delete[] Result_1;        
        delete[] Result_2;    
    }    
}

Test code

#include <iostream>

constexpr int N = 4;

// Declare the matrix creation function; it returns a pointer to pointer
int ** Creating_Matrix(int length);

// Declare the matrix destruction function
void Delete_Matrix(int **Matrix, int length);

// Declare the matrix addition function
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);

// Declare the matrix subtraction function
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);

// Declare Strassen's algorithm (the matrix multiplication function)
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length);


int main(void)
{
    int **A, **B, **C;

    A = Creating_Matrix(N);
    B = Creating_Matrix(N);
    C = Creating_Matrix(N);

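    // Fill A with 1..N*N and B with N*N+1..2*N*N
    int Number = 1;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            A[i][j] = Number;
            B[i][j] = Number + N * N;
            Number++;   // advance so that every entry gets a distinct value
        }
    }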

    std::cout << "Matrix A: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << A[i][j] << "\t\t";
        std::cout << std::endl;
    }
    std::cout << std::endl;

    std::cout << "Matrix B: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << B[i][j] << "\t\t";
        std::cout << std::endl;
    }
    std::cout << std::endl;

    STRASSEN_ALGORITHM(A, B, C, N);

    std::cout << "Matrix C = A * B: " << std::endl;
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
            std::cout << C[i][j] << "\t\t";
        std::cout << std::endl;
    }

    Delete_Matrix(A, N);
    Delete_Matrix(B, N);
    Delete_Matrix(C, N);

    return 0;
}


// Define the matrix creation function (the elements are left uninitialized)
int ** Creating_Matrix(int length)
{
    int **Temporary = new int *[length];

    for (int i = 0; i < length; i++)
        Temporary[i] = new int[length];

    return Temporary;
}

// Define the matrix destruction function
void Delete_Matrix(int **Matrix, int length)
{
    for (int i = 0; i < length; i++)
        delete[] Matrix[i];

    delete[] Matrix;
}

// Define the matrix addition function
void Add(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] + Matrix_B[i][j];
}

// Define the matrix subtraction function
void Subtract(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    for (int i = 0; i < length; i++)
        for (int j = 0; j < length; j++)
            Matrix_C[i][j] = Matrix_A[i][j] - Matrix_B[i][j];
}


// Define Strassen's algorithm (the matrix multiplication function)
void STRASSEN_ALGORITHM(int **Matrix_A, int **Matrix_B, int **Matrix_C, int length)
{
    if (length == 1)        
        Matrix_C[0][0] = Matrix_A[0][0] * Matrix_B[0][0];    
    else    
    {        
        int Middle = length / 2;
        
        
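        /* ------------------------------------------------------------
         *
         * A built-in 2-D array can only be passed to a function if its
         * second dimension is fixed, so we could not pass sub-matrices
         * of size length / 2 down the recursion that way.
         *
         * Instead we declare pointers to pointers. Each one points to an
         * array of pointers, and each of those pointers points to a 1-D
         * array, which together simulate a 2-D array.
         *
         * ------------------------------------------------------------ */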
         
        int **Matrix_A_11 = new int *[Middle];        
        int **Matrix_A_12 = new int *[Middle];        
        int **Matrix_A_21 = new int *[Middle];        
        int **Matrix_A_22 = new int *[Middle];
        
        int **Matrix_B_11 = new int *[Middle];        
        int **Matrix_B_12 = new int *[Middle];        
        int **Matrix_B_21 = new int *[Middle];        
        int **Matrix_B_22 = new int *[Middle];
        
        int **Matrix_C_11 = new int *[Middle];        
        int **Matrix_C_12 = new int *[Middle];        
        int **Matrix_C_21 = new int *[Middle];        
        int **Matrix_C_22 = new int *[Middle];
        
        int **M1 = new int *[Middle];        
        int **M2 = new int *[Middle];        
        int **M3 = new int *[Middle];        
        int **M4 = new int *[Middle];        
        int **M5 = new int *[Middle];        
        int **M6 = new int *[Middle];        
        int **M7 = new int *[Middle];
        
        int **Result_1 = new int *[Middle];        
        int **Result_2 = new int *[Middle];
        
        for (int i = 0; i < Middle; i++)        
        {
            Matrix_A_11[i] = new int [Middle];            
            Matrix_A_12[i] = new int [Middle];            
            Matrix_A_21[i] = new int [Middle];            
            Matrix_A_22[i] = new int [Middle];
            
            Matrix_B_11[i] = new int [Middle];            
            Matrix_B_12[i] = new int [Middle];            
            Matrix_B_21[i] = new int [Middle];            
            Matrix_B_22[i] = new int [Middle];
            
            Matrix_C_11[i] = new int [Middle];            
            Matrix_C_12[i] = new int [Middle];            
            Matrix_C_21[i] = new int [Middle];            
            Matrix_C_22[i] = new int [Middle];
            
            M1[i] = new int [Middle];            
            M2[i] = new int [Middle];            
            M3[i] = new int [Middle];            
            M4[i] = new int [Middle];            
            M5[i] = new int [Middle];            
            M6[i] = new int [Middle];            
            M7[i] = new int [Middle];
            
            Result_1[i] = new int [Middle];            
            Result_2[i] = new int [Middle];
        }
        
        
        /* ------------------------------------------------------------
         *
         * Now we fill these "2-D arrays" with the appropriate values:
         *
         * Matrix_A_11 takes the elements of Matrix_A with indices
         * (0 ~ Middle - 1) * (0 ~ Middle - 1);
         *
         * Matrix_A_12 takes the elements of Matrix_A with indices
         * (0 ~ Middle - 1) * (Middle ~ length - 1);
         *
         * Matrix_A_21 takes the elements of Matrix_A with indices
         * (Middle ~ length - 1) * (0 ~ Middle - 1);
         *
         * Matrix_A_22 takes the elements of Matrix_A with indices
         * (Middle ~ length - 1) * (Middle ~ length - 1).
         *
         * The blocks of Matrix_B are filled the same way.
         *
         * ------------------------------------------------------------ */

        for (int i = 0; i < Middle; i++)
        {            
            for (int j = 0; j < Middle; j++)            
            {                
                Matrix_A_11[i][j] = Matrix_A[i][j];                
                Matrix_A_12[i][j] = Matrix_A[i][j + Middle];                
                Matrix_A_21[i][j] = Matrix_A[i + Middle][j];                
                Matrix_A_22[i][j] = Matrix_A[i + Middle][j + Middle];
                Matrix_B_11[i][j] = Matrix_B[i][j];                
                Matrix_B_12[i][j] = Matrix_B[i][j + Middle];                
                Matrix_B_21[i][j] = Matrix_B[i + Middle][j];                
                Matrix_B_22[i][j] = Matrix_B[i + Middle][j + Middle];            
             }        
         }


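        /* ------------------------------------------------------------
         *
         * Recursive matrix multiplications
         *
         * Before each recursive call we use Add and Subtract to perform
         * the required additions and subtractions, storing the
         * intermediate results in Result_1 and Result_2.
         *
         * ------------------------------------------------------------ */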
        
        // M1        
        Add(Matrix_A_11, Matrix_A_22, Result_1, Middle);        
        Add(Matrix_B_11, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M1, Middle);
        
        // M2        
        Add(Matrix_A_21, Matrix_A_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_11, M2, Middle);
        
        // M3        
        Subtract(Matrix_B_12, Matrix_B_22, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_11, Result_1, M3, Middle);
        
        // M4        
        Subtract(Matrix_B_21, Matrix_B_11, Result_1, Middle);        
        STRASSEN_ALGORITHM(Matrix_A_22, Result_1, M4, Middle);
        
        // M5        
        Add(Matrix_A_11, Matrix_A_12, Result_1, Middle);        
        STRASSEN_ALGORITHM(Result_1, Matrix_B_22, M5, Middle);
        
        // M6        
        Subtract(Matrix_A_21, Matrix_A_11, Result_1, Middle);        
        Add(Matrix_B_11, Matrix_B_12, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M6, Middle);
        
        // M7        
        Subtract(Matrix_A_12, Matrix_A_22, Result_1, Middle);        
        Add(Matrix_B_21, Matrix_B_22, Result_2, Middle);        
        STRASSEN_ALGORITHM(Result_1, Result_2, M7, Middle);


        /* ------------------------------------------------------------
         *
         * Following the steps of Strassen's algorithm, we now combine the
         * recursively computed matrices M1...M7 with the required
         * additions and subtractions to obtain Matrix_C_11, Matrix_C_12,
         * Matrix_C_21, and Matrix_C_22.
         *
         * We call Add and Subtract for these operations and use Result_1
         * and Result_2 to hold the intermediate results.
         *
         * ------------------------------------------------------------ */

        // Matrix_C_11
        Add(M1, M4, Result_1, Middle);        
        Subtract(Result_1, M5, Result_2, Middle);        
        Add(Result_2, M7, Matrix_C_11, Middle);
        
        // Matrix_C_12        
        Add(M3, M5, Matrix_C_12, Middle);
        
        // Matrix_C_21        
        Add(M2, M4, Matrix_C_21, Middle);
        
        // Matrix_C_22        
        Subtract(M1, M2, Result_1, Middle);        
        Add(Result_1, M3, Result_2, Middle);        
        Add(Result_2, M6, Matrix_C_22, Middle);


        /* ————————————————————————————————————————————————————————————         
         *          
         * Now we stitch the four sub-matrices back together into one large matrix.
         *          
         * ———————————————————————————————————————————————————————————— */
                 
        for (int i = 0; i < Middle; i++)
        {
            for (int j = 0; j < Middle; j++)
            {
                Matrix_C[i][j] = Matrix_C_11[i][j];
                Matrix_C[i][j + Middle] = Matrix_C_12[i][j];
                Matrix_C[i + Middle][j] = Matrix_C_21[i][j];
                Matrix_C[i + Middle][j + Middle] = Matrix_C_22[i][j];
            }
        }


        /* ————————————————————————————————————————————————————————————         
         *          
         * Finally, release the dynamically allocated memory.
         *          
         * ———————————————————————————————————————————————————————————— */                 
        // Free each row first ...
        for (int i = 0; i < Middle; i++)
        {
            delete[] Matrix_A_11[i];
            delete[] Matrix_A_12[i];
            delete[] Matrix_A_21[i];
            delete[] Matrix_A_22[i];

            delete[] Matrix_B_11[i];
            delete[] Matrix_B_12[i];
            delete[] Matrix_B_21[i];
            delete[] Matrix_B_22[i];

            delete[] Matrix_C_11[i];
            delete[] Matrix_C_12[i];
            delete[] Matrix_C_21[i];
            delete[] Matrix_C_22[i];

            delete[] M1[i];
            delete[] M2[i];
            delete[] M3[i];
            delete[] M4[i];
            delete[] M5[i];
            delete[] M6[i];
            delete[] M7[i];

            delete[] Result_1[i];
            delete[] Result_2[i];
        }

        // ... then free the arrays of row pointers.
        delete[] Matrix_A_11;
        delete[] Matrix_A_12;
        delete[] Matrix_A_21;
        delete[] Matrix_A_22;

        delete[] Matrix_B_11;
        delete[] Matrix_B_12;
        delete[] Matrix_B_21;
        delete[] Matrix_B_22;

        delete[] Matrix_C_11;
        delete[] Matrix_C_12;
        delete[] Matrix_C_21;
        delete[] Matrix_C_22;

        delete[] M1;
        delete[] M2;
        delete[] M3;
        delete[] M4;
        delete[] M5;
        delete[] M6;
        delete[] M7;

        delete[] Result_1;
        delete[] Result_2;
        
    }    
}

/* ————————————————————————————————————————————————————————————     
 *  
 * Output:
 *  
 * Matrix A:
 * 1               2               3               4 
 * 5               6               7               8 
 * 9               10              11              12 
 * 13              14              15              16 
 *  
 * Matrix B:
 * 17              18              19              20 
 * 21              22              23              24 
 * 25              26              27              28 
 * 29              30              31              32 
 *  
 * Matrix C = A * B:
 * 250             260             270             280 
 * 618             644             670             696 
 * 986             1028            1070            1112 
 * 1354            1412            1470            1528 
 *  
 * ———————————————————————————————————————————————————————————— */
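
The seven products and four combination formulas used in the listing can be sanity-checked in isolation. The following is a minimal sketch, not part of the original program: it applies one level of Strassen to a 2×2 matrix whose "blocks" are single numbers and compares the result with the textbook product. The names `Mat2`, `NaiveMultiply2x2`, `StrassenMultiply2x2`, and `CheckStrassenAgainstNaive` are hypothetical helpers introduced only for this illustration.

```cpp
#include <array>
#include <cassert>

// Hypothetical helper type for this sketch: a 2x2 matrix stored row by row.
using Mat2 = std::array<std::array<long long, 2>, 2>;

// Textbook 2x2 product (8 scalar multiplications), used as the reference.
Mat2 NaiveMultiply2x2(const Mat2& a, const Mat2& b) {
    Mat2 c{};  // value-initialized to all zeros
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            for (int k = 0; k < 2; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

// One level of Strassen where each block is a single number:
// 7 scalar multiplications instead of 8.
Mat2 StrassenMultiply2x2(const Mat2& a, const Mat2& b) {
    const long long m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
    const long long m2 = (a[1][0] + a[1][1]) * b[0][0];
    const long long m3 = a[0][0] * (b[0][1] - b[1][1]);
    const long long m4 = a[1][1] * (b[1][0] - b[0][0]);
    const long long m5 = (a[0][0] + a[0][1]) * b[1][1];
    const long long m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
    const long long m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);

    Mat2 c{};
    c[0][0] = m1 + m4 - m5 + m7;  // C_11
    c[0][1] = m3 + m5;            // C_12
    c[1][0] = m2 + m4;            // C_21
    c[1][1] = m1 - m2 + m3 + m6;  // C_22
    return c;
}

// Aborts (in builds where assert is enabled) if the two methods disagree.
void CheckStrassenAgainstNaive(const Mat2& a, const Mat2& b) {
    assert(StrassenMultiply2x2(a, b) == NaiveMultiply2x2(a, b));
}
```

Calling `CheckStrassenAgainstNaive` on a few small inputs from a `main` function is a quick way to catch a sign or index slip in the M1...M7 formulas before debugging the full recursive version.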