Strassen算法要求参与运算的两个矩阵都是方阵,并且维数是2的幂。如果两个矩阵不满足这两个条件,可以补充元素均为0的行与列使其满足条件。假设要计算C=A*B。
首先将A,B,C均分解为4个n/2*n/2的子矩阵。假设将A分解为A11,A12,A21,A22。 B分解为B11,B12, B21, B22。 对C做类似的操作。
创建10个n/2*n/2的矩阵 S1,S2,…, S10。
- S1=B12-B22
- S2=A11+A12
- S3=A21+A22
- S4=B21-B11
- S5=A11+A22
- S6=B11+B22
- S7=A12-A22
- S8=B21+B22
- S9=A11-A21
- S10=B11+B12
递归地计算7次n/2*n/2矩阵的乘法
- P1=A11*S1
- P2=S2*B22
- P3=S3*B11
- P4=A22*S4
- P5=S5*S6
- P6=S7*S8
- P7=S9*S10
对步骤3创建的7个矩阵进行加减运算,计算出C的4个子矩阵
- C11=P5+P4-P2+P6
- C12=P1+P2
- C21=P3+P4
- C22=P5+P1-P3-P7
代码如下:
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>
// Copies the window [lowRow, highRow) x [lowColumn, highColumn) of A into subMatr.
//
// A          source matrix, stored as a vector of rows.
// subMatr    output; its previous contents are discarded and it ends up with
//            (highRow-lowRow) rows of (highColumn-lowColumn) elements each.
//            It must not be the same object as A.
//
// Throws std::out_of_range when the window is negative, inverted, or does not
// lie inside A. The whole window is validated before subMatr is touched, so a
// failed call leaves subMatr unchanged.
template<class Type>
void subMatrices(int lowRow,int highRow,int lowColumn,int highColumn,const std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& subMatr)
{
    if(lowRow<0||lowColumn<0||highRow<lowRow||highColumn<lowColumn)
        throw std::out_of_range("subMatrices: invalid row or column range");
    if(static_cast<std::size_t>(highRow)>A.size())
        throw std::out_of_range("subMatrices: row range exceeds the source matrix");

    const std::size_t firstRow=static_cast<std::size_t>(lowRow);
    const std::size_t rowCount=static_cast<std::size_t>(highRow-lowRow);

    // Rows may have different lengths, so every selected row is checked.
    for(std::size_t i=0;i!=rowCount;++i)
        if(static_cast<std::size_t>(highColumn)>A[firstRow+i].size())
            throw std::out_of_range("subMatrices: column range exceeds the source matrix");

    subMatr.resize(rowCount);
    for(std::size_t i=0;i!=rowCount;++i){
        const std::vector<Type>& sourceRow=A[firstRow+i];
        subMatr[i].assign(sourceRow.begin()+lowColumn,sourceRow.begin()+highColumn);
    }
}
// Computes the element-wise difference C = A - B.
//
// A, B   operands; they must have the same number of rows and every pair of
//        corresponding rows must have the same length.
// C      output; resized to the shape of A. C may be the same object as A or B,
//        in which case the result is written in place.
//
// Throws std::runtime_error when either operand has no rows or when the shapes
// differ. Shapes are validated row by row before C is modified, so a mismatch
// in a later row cannot cause an out-of-bounds access.
template<class Type>
void matrixSubtract(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
    if(A.empty()||B.empty())
        throw std::runtime_error("one of matices is empty!");
    if(A.size()!=B.size())
        throw std::runtime_error("the dimensions1 of the two matrices are not same!");
    for(std::size_t i=0;i!=A.size();++i)
        if(A[i].size()!=B[i].size())
            throw std::runtime_error("the dimensions1 of the two matrices are not same!");

    C.resize(A.size());
    for(std::size_t i=0;i!=A.size();++i){
        C[i].resize(A[i].size(),Type());
        for(std::size_t j=0;j!=A[i].size();++j)
            C[i][j]=A[i][j]-B[i][j];
    }
}
// Computes the element-wise sum C = A + B.
//
// A, B   operands; they must have the same number of rows and every pair of
//        corresponding rows must have the same length.
// C      output; resized to the shape of A. C may be the same object as A or B,
//        in which case the result is written in place.
//
// Throws std::runtime_error when either operand has no rows or when the shapes
// differ. Shapes are validated row by row before C is modified, so a mismatch
// in a later row cannot cause an out-of-bounds access.
template<class Type>
void matrixAdd(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
    if(A.empty()||B.empty())
        throw std::runtime_error("one of matices is empty!");
    if(A.size()!=B.size())
        throw std::runtime_error("the dimensions22 of the two matrices are not same!");
    for(std::size_t i=0;i!=A.size();++i)
        if(A[i].size()!=B[i].size())
            throw std::runtime_error("the dimensions22 of the two matrices are not same!");

    C.resize(A.size());
    for(std::size_t i=0;i!=A.size();++i){
        C[i].resize(A[i].size(),Type());
        for(std::size_t j=0;j!=A[i].size();++j)
            C[i][j]=A[i][j]+B[i][j];
    }
}
// Pads A and B in place with value-initialised elements (Type()) so that both
// become square matrices of the same size, that size being the smallest power
// of two that is >= every row count and every row length of A and B.
//
// Existing elements keep their positions; nothing is truncated. Empty inputs
// are accepted and become 1x1 matrices holding Type().
//
// The power of two is found by integer doubling rather than floating-point
// pow(), which avoids rounding problems and the cost of a transcendental call.
template<class Type>
void dimensionWithPowerOf2(std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& B)
{
    // Largest extent among the row counts and all row lengths of both matrices.
    std::size_t required=A.size()<B.size()?B.size():A.size();
    for(std::size_t i=0;i!=A.size();++i)
        if(A[i].size()>required)
            required=A[i].size();
    for(std::size_t i=0;i!=B.size();++i)
        if(B[i].size()>required)
            required=B[i].size();

    std::size_t padded=1;
    while(padded<required)
        padded*=2;

    A.resize(padded);
    for(std::size_t i=0;i!=padded;++i)
        A[i].resize(padded,Type());
    B.resize(padded);
    for(std::size_t i=0;i!=padded;++i)
        B[i].resize(padded,Type());
}
// Helper for strassenRecur: copies `block` into C with the block's top-left
// element landing at C[rowOffset][columnOffset]. C must already be large enough.
template<class Type>
void strassenPlaceQuadrant(const std::vector< std::vector<Type> >& block,std::size_t rowOffset,std::size_t columnOffset,std::vector< std::vector<Type> >& C)
{
    for(std::size_t i=0;i!=block.size();++i)
        for(std::size_t j=0;j!=block[i].size();++j)
            C[rowOffset+i][columnOffset+j]=block[i][j];
}

// Recursive step of Strassen's algorithm: computes C = A * B.
//
// A, B   square matrices of the same size n. Above the cutoff n must be even at
//        every recursion level (a power of two satisfies this); otherwise the
//        quadrants differ in size and the add/subtract helpers throw.
// C      output; resized to n x n before anything else happens.
//
// Matrices of size <= 128 are handed to directMultiply, which is not defined in
// this block and must be declared before this template (presumably a
// conventional triple-loop product that writes A*B into C - confirm against its
// definition). The test is "<=" rather than "==" so that inputs smaller than
// the cutoff terminate there instead of recursing down to empty sub-matrices.
// Stopping at 128 instead of 1 also saves a great deal of time.
template<class Type>
void strassenRecur(const std::vector< std::vector<Type> >& A, const std::vector< std::vector<Type> >& B,std::vector< std::vector<Type> >& C)
{
    typedef std::vector< std::vector<Type> > Matrix;

    const std::size_t n=A.size();
    const std::size_t cutoff=128;

    C.resize(n);
    for(std::size_t i=0;i!=n;++i)
        C[i].resize(n,Type());

    if(n==0)
        return; // nothing to multiply; C is already empty
    if(n<=cutoff){
        directMultiply(A,B,C);
        return;
    }

    const int full=static_cast<int>(n);
    const int mid=full/2;
    const std::size_t half=static_cast<std::size_t>(mid);

    // split A into four quadrants
    Matrix A11,A12,A21,A22;
    subMatrices(0,mid,0,mid,A,A11);
    subMatrices(0,mid,mid,full,A,A12);
    subMatrices(mid,full,0,mid,A,A21);
    subMatrices(mid,full,mid,full,A,A22);

    // split B into four quadrants
    Matrix B11,B12,B21,B22;
    subMatrices(0,mid,0,mid,B,B11);
    subMatrices(0,mid,mid,full,B,B12);
    subMatrices(mid,full,0,mid,B,B21);
    subMatrices(mid,full,mid,full,B,B22);

    // P1 = A11*S1 = A11*(B12-B22)
    Matrix s1,p1;
    matrixSubtract(B12,B22,s1);
    strassenRecur(A11,s1,p1);

    // P2 = S2*B22 = (A11+A12)*B22
    Matrix s2,p2;
    matrixAdd(A11,A12,s2);
    strassenRecur(s2,B22,p2);

    // P3 = S3*B11 = (A21+A22)*B11
    Matrix s3,p3;
    matrixAdd(A21,A22,s3);
    strassenRecur(s3,B11,p3);

    // P4 = A22*S4 = A22*(B21-B11)
    Matrix s4,p4;
    matrixSubtract(B21,B11,s4);
    strassenRecur(A22,s4,p4);

    // P5 = S5*S6 = (A11+A22)*(B11+B22)
    Matrix s5,s6,p5;
    matrixAdd(A11,A22,s5);
    matrixAdd(B11,B22,s6);
    strassenRecur(s5,s6,p5);

    // P6 = S7*S8 = (A12-A22)*(B21+B22)
    Matrix s7,s8,p6;
    matrixSubtract(A12,A22,s7);
    matrixAdd(B21,B22,s8);
    strassenRecur(s7,s8,p6);

    // P7 = S9*S10 = (A11-A21)*(B11+B12)
    Matrix s9,s10,p7;
    matrixSubtract(A11,A21,s9);
    matrixAdd(B11,B12,s10);
    strassenRecur(s9,s10,p7);

    Matrix tmp1,tmp2;

    // C11 = P5+P4-P2+P6
    matrixAdd(p5,p4,tmp1);
    matrixSubtract(tmp1,p2,tmp2);
    matrixAdd(tmp2,p6,tmp1);
    strassenPlaceQuadrant(tmp1,0,0,C);

    // C12 = P1+P2
    matrixAdd(p1,p2,tmp1);
    strassenPlaceQuadrant(tmp1,0,half,C);

    // C21 = P3+P4
    matrixAdd(p3,p4,tmp1);
    strassenPlaceQuadrant(tmp1,half,0,C);

    // C22 = P5+P1-P3-P7
    matrixAdd(p5,p1,tmp1);
    matrixSubtract(tmp1,p3,tmp2);
    matrixSubtract(tmp2,p7,tmp1);
    strassenPlaceQuadrant(tmp1,half,half,C);
}
// Computes C = A * B using Strassen's algorithm.
//
// A   left operand, rows x inner.
// B   right operand, inner x columns.
// C   output; ends up with A.size() rows of B[0].size() elements each.
//
// A and B are not modified: the zero-padding to a power-of-two square size is
// done on internal copies. They remain non-const references only so that
// existing callers keep compiling. Because the copies are taken before C is
// written, C may be the same object as A or B.
//
// Throws std::runtime_error when either operand has no rows, or when the length
// of A's first row differs from the number of rows of B.
template<class Type>
void strassen( std::vector< std::vector<Type> >& A,std::vector< std::vector<Type> >& B, std::vector< std::vector<Type> >& C)
{
    if(A.empty()||B.empty())
        throw std::runtime_error("one of the two matrices is empty!");
    if(A[0].size()!=B.size())
        throw std::runtime_error("The column number of the first matrice is unequal to the row number of the second matrice!");

    const std::size_t resultRows=A.size();
    const std::size_t resultColumns=B[0].size();

    // Pad copies of the operands up to a common power-of-two square size.
    std::vector< std::vector<Type> > paddedA(A);
    std::vector< std::vector<Type> > paddedB(B);
    dimensionWithPowerOf2(paddedA,paddedB);

    // Multiply the padded matrices.
    std::vector< std::vector<Type> > multiplyResult;
    strassenRecur(paddedA,paddedB,multiplyResult);

    // Keep only the top-left resultRows x resultColumns part: the real product.
    C.resize(resultRows);
    for(std::size_t i=0;i!=resultRows;++i)
        C[i].assign(multiplyResult[i].begin(),multiplyResult[i].begin()+resultColumns);
}