Strassen矩阵乘法

矩阵的乘法法则

设 A = ( a i j ) , B = ( b i j ) 是 n × n 的方阵,则对 i , j = 1 , 2 , . . . , n ,定义 A 和 B 的乘积矩阵 C 中元素 c i j 如下 c i j = ∑ k = 1 n a i k ⋅ b k j 设A=(a_{ij}),B=(b_{ij})是n\times n 的方阵,则对i,j=1,2,...,n,定义A和B的乘积矩阵C中元素c_{ij}如下 \\ c_{ij}=\sum_{k=1}^{n}a_{ik}\cdot b_{kj} A=(aij)B=(bij)n×n的方阵,则对i,j=1,2,...,n,定义AB的乘积矩阵C中元素cij如下cij=k=1naikbkj
按照定义来计算,我们需要计算n^2个矩阵元素,每个元素是n个值的和(从k=1到n),也就是说,要每计算一个C矩阵的元素,就要进行n次乘法和n-1次加法。

#include<iostream>
using namespace std;

int main(){
    int n=0;
    cin>>n;
    int** matrixA=new int* [n]; //A矩阵
    int** matrixB=new int* [n]; //B矩阵
    int** matrixC=new int* [n];  //C矩阵

    for(int i=0;i<n;i++){
        matrixA[i]=new int [n];
        matrixB[i]=new int [n];
        matrixC[i]=new int [n];
    }
 
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
            cin>>matrixA[i][j];
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
            cin>>matrixB[i][j];
   //三重for循环实现矩阵乘积公式
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++){
            matrixC[i][j]=0;
            for(int k=0;k<n;k++)
                matrixC[i][j]=matrixC[i][j]+matrixA[i][k]*matrixB[k][j];
        }


    for(int i=0;i<n;i++){
        
        for(int j=0;j<n;j++)
            cout<<matrixC[i][j]<<" ";
        cout<<endl;
    }
    
}

显然依照定义来计算的时间复杂度将达到O(n^3)。

Strassen算法

假设三个矩阵均为n*n的矩阵,其中n为2的幂。那么可以将n*n的矩阵划分为4个n/2*n/2的子矩阵(不影响结果,详细见线性代数的矩阵分块法)
A = [ A 11 A 12 A 21 A 22 ] , B = [ B 11 B 12 B 21 B 22 ] , C = [ C 11 C 12 C 21 C 22 ] A=\begin{bmatrix} A_{11}&A_{12} \\ A_{21}&A_{22}\end{bmatrix},B=\begin{bmatrix}B_{11}&B_{12} \\B_{21}&B_{22}\end{bmatrix},C=\begin{bmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{bmatrix} A=[A11A21A12A22],B=[B11B21B12B22],C=[C11C21C12C22]
因此可以将C=A*B,写成
[ C 11 C 12 C 21 C 22 ] = [ A 11 A 12 A 21 A 22 ] ⋅ [ B 11 B 12 B 21 B 22 ] \begin{bmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{bmatrix}=\begin{bmatrix} A_{11}&A_{12} \\ A_{21}&A_{22}\end{bmatrix} \cdot \begin{bmatrix}B_{11}&B_{12} \\B_{21}&B_{22}\end{bmatrix} [C11C21C12C22]=[A11A21A12A22][B11B21B12B22]
也就是
C 11 = A 11 ⋅ B 11 + A 12 ⋅ B 21 C 12 = A 11 ⋅ B 12 + A 12 ⋅ B 22 C 21 = A 21 ⋅ B 11 + A 22 ⋅ B 21 C 22 = A 21 ⋅ B 22 + A 12 ⋅ B 22 C_{11}=A_{11}\cdot B_{11}+A_{12}\cdot B_{21} \\ C_{12}=A_{11}\cdot B_{12}+A_{12}\cdot B_{22} \\ C_{21}=A_{21}\cdot B_{11}+A_{22}\cdot B_{21} \\ C_{22}=A_{21}\cdot B_{22}+A_{12}\cdot B_{22} \\ C11=A11B11+A12B21C12=A11B12+A12B22C21=A21B11+A22B21C22=A21B22+A12B22
每个公式对应两对n/2*n/2矩阵的乘法和n/2*n/2的矩阵积的加法,也就是说应用分治思想来求解矩阵乘法,也就是应用矩阵分块法来求解。如果是一个二阶矩阵(4*4),那么只用分解一次就可以直接计算。
当矩阵的阶数大于2时,可以一直将矩阵分块,直到分解得到的子矩阵的阶数为2。由上面的公式,计算n阶方阵的乘积就转换成了计算8个n/2的方阵的乘积和4个n/2的方阵的加法。
T ( n ) = { O ( 1 ) 8 T ( n / 2 ) + O ( n 2 ) T(n)=\begin{cases} O(1) \\ \\ 8T(n/2)+O(n^2) \end{cases} T(n)= O(1)8T(n/2)+O(n2)
使用分治法直接计算矩阵乘法并不比原始定义更加有效,时间复杂度仍然为O(n^3)

#include<iostream>
using namespace std;

int** matrix_add(int** A,int** B,int n);
int** clone(int** A,int x1,int y1,int x2,int y2);
int** Strassen(int** A,int** B,int n);

int main(){
    int n=0;
    cin>>n;
    int** matrixA=new int* [n+1];
    int** matrixB=new int* [n+1];

    for(int i=1;i<=n;i++){
        matrixA[i]=new int [n+1];
        matrixB[i]=new int [n+1];
    }
 
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            cin>>matrixA[i][j];
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            cin>>matrixB[i][j];

    int** matrixC=Strassen(matrixA,matrixB,n);

    for(int i=1;i<=n;i++){
        
        for(int j=1;j<=n;j++)
            cout<<matrixC[i][j]<<" ";
        cout<<endl;
    }
    
}

int** Strassen(int** A,int** B,int n){
    int **C=new int* [n+1];
    for(int i=1;i<=n;i++)
        C[i]=new int[n+1];
    if(n==1)
        C[1][1]=A[1][1]*B[1][1];
    else{
        //分解矩阵
        /*
        【1 1】
        【1 1】
        n=2;
        A11=clone(1,1,1,1)
        A12=clone(1,2,1,2)
        
        */
        int **A11=clone(A,1,1,n/2,n/2);
        int **A12=clone(A,1,n/2+1,n/2,n);
        int **A21=clone(A,n/2+1,1,n,n/2);
        int **A22=clone(A,n/2+1,n/2+1,n,n);
        int **B11=clone(B,1,1,n/2,n/2);
        int **B12=clone(B,1,n/2+1,n/2,n);
        int **B21=clone(B,n/2+1,1,n,n/2);
        int **B22=clone(B,n/2+1,n/2+1,n,n);
        
        int **C11=matrix_add(Strassen(A11,B11,n/2),Strassen(A12,B21,n/2),n/2);
        int **C12=matrix_add(Strassen(A11,B12,n/2),Strassen(A12,B22,n/2),n/2);
        int **C21=matrix_add(Strassen(A21,B11,n/2),Strassen(A22,B21,n/2),n/2);
        int **C22=matrix_add(Strassen(A21,B12,n/2),Strassen(A22,B22,n/2),n/2);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++){
                if(i<=n/2&&j<=n/2)
                    C[i][j]=C11[i][j];
                
                else if(i<=n/2&&j>n/2){
                    C[i][j]=C12[i][j-n/2];
                }
                else if(i>n/2&&j<=n/2){
                    C[i][j]=C21[i-n/2][j];
                }
                else{
                    C[i][j]=C22[i-n/2][j-n/2];
                }
            }
        

    }
    return C;

}
int** clone(int** A,int x1,int y1,int x2,int y2){
    int **temp=new int*[x2-x1+2];
    for(int i=x1;i<=x2;i++)
        temp[i-x1+1]=new int[x2-x1+2];
    for(int i=x1;i<=x2;i++){
        for(int j=y1;j<=y2;j++)
            temp[i-x1+1][j-y1+1]=A[i][j];
    }
    return temp;
}
int** matrix_add(int** A,int** B,int n){
    int **temp=new int*[n+1];
    for(int i=1;i<=n;i++)
        temp[i]=new int[n+1];
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            temp[i][j]=A[i][j]+B[i][j];
    
    return temp;
}

使用strassen算法优化后,对应公式写代码即可。

#include<iostream>
using namespace std;

int** matrix_add(int** A,int** B,int n);
int** clone(int** A,int x1,int y1,int x2,int y2);
int** Strassen(int** A,int** B,int n);
int** matrix_sub(int** A,int** B,int n);

int main(){
    int n=0;
    cin>>n;
    int** matrixA=new int* [n+1];
    int** matrixB=new int* [n+1];

    for(int i=1;i<=n;i++){
        matrixA[i]=new int [n+1];
        matrixB[i]=new int [n+1];
    }
 
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            cin>>matrixA[i][j];
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            cin>>matrixB[i][j];

    int** matrixC=Strassen(matrixA,matrixB,n);

    for(int i=1;i<=n;i++){
        
        for(int j=1;j<=n;j++)
            cout<<matrixC[i][j]<<" ";
        cout<<endl;
    }
    
}

int** Strassen(int** A,int** B,int n){
    int **C=new int* [n+1];
    for(int i=1;i<=n;i++)
        C[i]=new int[n+1];
    if(n==1)
        C[1][1]=A[1][1]*B[1][1];
    else{
        //分解矩阵 ,根据下标进行分解
        /*
        【1 1】
        【1 1】
        n=2;
        A11=clone(1,1,1,1)
        A12=clone(1,2,1,2)
        
        */
        int **A11=clone(A,1,1,n/2,n/2);
        int **A12=clone(A,1,n/2+1,n/2,n);
        int **A21=clone(A,n/2+1,1,n,n/2);
        int **A22=clone(A,n/2+1,n/2+1,n,n);
        int **B11=clone(B,1,1,n/2,n/2);
        int **B12=clone(B,1,n/2+1,n/2,n);
        int **B21=clone(B,n/2+1,1,n,n/2);
        int **B22=clone(B,n/2+1,n/2+1,n,n);

        int **S1=matrix_sub(B12,B22,n/2);
        int **S2=matrix_add(A11,A12,n/2);
        int **S3=matrix_add(A21,A22,n/2);
        int **S4=matrix_sub(B21,B11,n/2);
        int **S5=matrix_add(A11,A22,n/2);
        int **S6=matrix_add(B11,B22,n/2);
        int **S7=matrix_sub(A12,A22,n/2);
        int **S8=matrix_add(B21,B22,n/2);
        int **S9=matrix_sub(A11,A21,n/2);
        int **S10=matrix_add(B11,B12,n/2);
        
        /*
        int **C11=matrix_add(Strassen(A11,B11,n/2),Strassen(A12,B21,n/2),n/2);
        int **C12=matrix_add(Strassen(A11,B12,n/2),Strassen(A12,B22,n/2),n/2);
        int **C21=matrix_add(Strassen(A21,B11,n/2),Strassen(A22,B21,n/2),n/2);
        int **C22=matrix_add(Strassen(A21,B12,n/2),Strassen(A22,B22,n/2),n/2);
        */
       
       int** P1=Strassen(A11,S1,n/2);
       int** P2=Strassen(S2,B22,n/2);
       int** P3=Strassen(S3,B11,n/2);
       int** P4=Strassen(A22,S4,n/2);
       int** P5=Strassen(S5,S6,n/2);
       int** P6=Strassen(S7,S8,n/2);
       int** P7=Strassen(S9,S10,n/2);

       int **C11=matrix_add(matrix_sub(matrix_add(P5,P4,n/2),P2,n/2),P6,n/2); //C11=P5+P4-P2+P6;
       int **C12=matrix_add(P1,P2,n/2); //C12=P1+P2
       int **C21=matrix_add(P3,P4,n/2); //C21=P3+P4
       int **C22=matrix_sub(matrix_add(matrix_sub(P1,P3,n/2),P5,n/2),P7,n/2); //C22=P5+P1-P3-P7
       /*
       将分解后的乘积矩阵重新组装
       */
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++){
                if(i<=n/2&&j<=n/2)
                    C[i][j]=C11[i][j];
                
                else if(i<=n/2&&j>n/2){
                    C[i][j]=C12[i][j-n/2];
                }
                else if(i>n/2&&j<=n/2){
                    C[i][j]=C21[i-n/2][j];
                }
                else{
                    C[i][j]=C22[i-n/2][j-n/2];
                }
            }
        

    }
    return C;

}
int** clone(int** A,int x1,int y1,int x2,int y2){
    int **temp=new int*[x2-x1+2];
    for(int i=x1;i<=x2;i++)
        temp[i-x1+1]=new int[x2-x1+2];
    for(int i=x1;i<=x2;i++){
        for(int j=y1;j<=y2;j++)
            temp[i-x1+1][j-y1+1]=A[i][j];
    }
    return temp;
}
int** matrix_add(int** A,int** B,int n){
    int **temp=new int*[n+1];
    for(int i=1;i<=n;i++)
        temp[i]=new int[n+1];
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            temp[i][j]=A[i][j]+B[i][j];
    
    return temp;
}
int** matrix_sub(int** A,int** B,int n){
    int** temp=new int*[n+1];
    
    for(int i=1;i<=n;i++)
        temp[i]=new int[n+1];
    
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            temp[i][j]=A[i][j]-B[i][j];
    
    return temp;
}
/*
4
1 2 3 4
4 3 2 1
5 6 7 8
9 8 7 1
1 2 2 3
1 3 3 2
5 6 7 8
1 1 1 1
*/

Strassen算法运行时间T(n)的递归式
T ( n ) = { Θ ( 1 )   n = 1 7 T ( n / 2 ) + Θ ( n 2 )   n > 1 T(n)=\begin{cases} \Theta(1) \ n=1\\ \\ 7T(n/2)+ \Theta(n^2) \ n>1 \end{cases} T(n)= Θ(1) n=17T(n/2)+Θ(n2) n>1

  • 20
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值