SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)
1 n = A.rows //A的行数
2 let C be a new n*n matrix //让C变成新的n*n矩阵
3 if n == 1
4 c11 = a11 * b11
5 else partition A,B,and C as in equations //将三个矩阵各自分成4个部分
//分别求出四个元素
6 C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11)
+ SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)
7 C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12)
+ SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)
8 C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11)
+ SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)
9 C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12)
+ SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)
10 return C
Strassen(A,B)
n = A.rows //A的行数
let C be a new n*n matrix
if n == 1:
C = A * B
else partition A,B,and C //步骤1:将三个矩阵各自分为四部分
//步骤2:计算10个S
S1=B12-B22
S2=A11+A12
S3=A21+A22
S4=B21-B11
S5=A11+A22
S6=B11+B22
S7=A12-A22
S8=B21+B22
S9=A11-A21
S10=B11+B12
//步骤3:递归计算7个矩阵积
P1=Strassen(A11,S1)
P2=Strassen(S2,B22)
P3=Strassen(S3,B11)
P4=Strassen(A22,S4)
P5=Strassen(S5,S6)
P6=Strassen(S7,S8)
P7=Strassen(S9,S10)
//步骤4:不同Pi的加减运算
C11=P5+P4-P2+P6
C12=P1+P2
C21=P3+P4
C22=P5+P1-P3-P7
return C
二、C++代码
#include<iostream>#include<Windows.h>usingnamespace std;template<typenameT>classStrassen_class{public:voidADD(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize);voidSUB(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize);voidMUL(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize);//朴素算法实现voidFillMatrix(T** MatrixA, T** MatrixB,int length);//A,B矩阵赋值voidPrintMatrix(T** MatrixA,int MatrixSize);//打印矩阵voidStrassen(int N, T** MatrixA, T** MatrixB, T** MatrixC);//Strassen算法实现};//矩阵相加template<typenameT>voidStrassen_class<T>::ADD(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize){for(int i =0; i < MatrixSize; i++){for(int j =0; j < MatrixSize; j++){
MatrixResult[i][j]= MatrixA[i][j]+ MatrixB[i][j];}}}//矩阵相减template<typenameT>voidStrassen_class<T>::SUB(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize){for(int i =0; i < MatrixSize; i++){for(int j =0; j < MatrixSize; j++){
MatrixResult[i][j]= MatrixA[i][j]- MatrixB[i][j];}}}//普通的矩阵乘法template<typenameT>voidStrassen_class<T>::MUL(T** MatrixA, T** MatrixB, T** MatrixResult,int MatrixSize){for(int i =0; i < MatrixSize; i++){for(int j =0; j < MatrixSize; j++){
MatrixResult[i][j]=0;for(int k =0; k < MatrixSize; k++){
MatrixResult[i][j]= MatrixResult[i][j]+ MatrixA[i][k]* MatrixB[k][j];}}}}//A、B矩阵赋值template<typenameT>voidStrassen_class<T>::FillMatrix(T** MatrixA, T** MatrixB,int length){for(int row =0; row < length; row++){for(int column =0; column < length; column++){//给矩阵里赋值0到4的随机数
MatrixB[row][column]=(MatrixA[row][column]=rand()%5);}}}//打印矩阵template<typenameT>voidStrassen_class<T>::PrintMatrix(T** MatrixA,int MatrixSize){
cout << endl;for(int row =0; row < MatrixSize; row++){for(int column =0; column < MatrixSize; column++){
cout << MatrixA[row][column]<<"\t";if((column +1)%((MatrixSize))==0)
cout << endl;}}
cout << endl;}//Strassen算法template<typenameT>voidStrassen_class<T>::Strassen(int N, T **MatrixA, T **MatrixB, T **MatrixC){int HalfSize = N /2;int newSize = N /2;//当不能分成4个4*4的数组时,我们就采用正常的办法if(N <=64){MUL(MatrixA, MatrixB, MatrixC, N);}else{//创建多个二维数组
T** A11; T** A12; T** A21; T** A22;
T** B11; T** B12; T** B21; T** B22;
T** C11; T** C12; T** C21; T** C22;
T** M1; T** M2; T** M3; T** M4;
T** M5; T** M6; T** M7;
T** AResult; T** BResult;//创建一个一维数组的指针,用于寻找首地址
A11 =new T *[newSize];
A12 =new T *[newSize];
A21 =new T *[newSize];
A22 =new T *[newSize];
B11 =new T *[newSize];
B12 =new T *[newSize];
B21 =new T *[newSize];
B22 =new T *[newSize];
C11 =new T *[newSize];
C12 =new T *[newSize];
C21 =new T *[newSize];
C22 =new T *[newSize];
M1 =new T *[newSize];
M2 =new T *[newSize];
M3 =new T *[newSize];
M4 =new T *[newSize];
M5 =new T *[newSize];
M6 =new T *[newSize];
M7 =new T *[newSize];
AResult =new T *[newSize];
BResult =new T *[newSize];int newLength = newSize;//N/2长度//在上面一维数组的基础上,分别在每一行再创建一个一维数组的指针,从而实现一个二维数组for(int i =0; i < newSize; i++){
A11[i]=new T[newLength];
A12[i]=new T[newLength];
A21[i]=new T[newLength];
A22[i]=new T[newLength];
B11[i]=new T[newLength];
B12[i]=new T[newLength];
B21[i]=new T[newLength];
B22[i]=new T[newLength];
C11[i]=new T[newLength];
C12[i]=new T[newLength];
C21[i]=new T[newLength];
C22[i]=new T[newLength];
M1[i]=new T[newLength];
M2[i]=new T[newLength];
M3[i]=new T[newLength];
M4[i]=new T[newLength];
M5[i]=new T[newLength];
M6[i]=new T[newLength];
M7[i]=new T[newLength];
AResult[i]=new T[newLength];
BResult[i]=new T[newLength];}//将输入的数组四等分成N/2*N/2的数组,将A和B中的数组各自赋值给自己的四个分支数组for(int i =0; i < N /2; i++){for(int j =0; j < N /2; j++){
A11[i][j]= MatrixA[i][j];
A12[i][j]= MatrixA[i][j + N /2];
A21[i][j]= MatrixA[i + N /2][j];
A22[i][j]= MatrixA[i + N /2][j + N /2];
B11[i][j]= MatrixB[i][j];
B12[i][j]= MatrixB[i][j + N /2];
B21[i][j]= MatrixB[i + N /2][j];
B22[i][j]= MatrixB[i + N /2][j + N /2];}}//计算7个矩阵//M1=A11(B12-B22) SUB(B12, B22, BResult, HalfSize);Strassen(HalfSize, A11, BResult, M1);//M2=(A11+A12)B22 ADD(A11, A12, AResult, HalfSize);Strassen(HalfSize, AResult, B22, M2);//M3=(A21+A22)B11 ADD(A21, A22, AResult, HalfSize);Strassen(HalfSize, AResult, B11, M3);//M4=A22(B21-B11) SUB(B21, B11, BResult, HalfSize);Strassen(HalfSize, A22, BResult, M4);//M5=(A11+A22)(B11+B22)ADD(A11, A22, AResult, HalfSize);ADD(B11, B22, BResult, HalfSize);Strassen(HalfSize, AResult, BResult, M5);//M6=(A12-A22)(B21+B22) SUB(A12, A22, AResult, HalfSize);ADD(B21, B22, BResult, HalfSize);Strassen(HalfSize, AResult, BResult, M6);//M7=(A11-A21)(B11+B12)SUB(A11, A21, AResult, HalfSize);ADD(B11, B12, BResult, HalfSize);Strassen(HalfSize, AResult, BResult, M6);//C11 = M5 + M4 - M2 + M6;ADD(M5, M4, AResult, HalfSize);SUB(M6, M2, BResult, HalfSize);ADD(AResult, BResult, C11, HalfSize);//C12 = M1 + M1;ADD(M1, M2, C12, HalfSize);//C21 = M3 + M4;ADD(M3, M4, C21, HalfSize);//C22 = M5 + M1 - M3 - M7;ADD(M5, M1, AResult, HalfSize);ADD(M7, M3, BResult, HalfSize);SUB(AResult, BResult, C22, HalfSize);//组合小矩阵到一个大矩阵for(int i =0; i < N /2; i++){for(int j =0; j < N /2; j++){
MatrixC[i][j]= C11[i][j];
MatrixC[i][j + N /2]= C12[i][j];
MatrixC[i + N /2][j]= C21[i][j];
MatrixC[i + N /2][j + N /2]= C22[i][j];}}// 释放矩阵内存空间for(int i =0; i < newLength; i++){delete[] A11[i];delete[] A12[i];delete[] A21[i];delete[] A22[i];delete[] B11[i];delete[] B12[i];delete[] B21[i];delete[] B22[i];delete[] C11[i];delete[] C12[i];delete[] C21[i];delete[] C22[i];delete[] M1[i];delete[] M2[i];delete[] M3[i];delete[] M4[i];delete[] M5[i];delete[] M6[i];delete[] M7[i];delete[] AResult[i];delete[] BResult[i];}delete[] A11;delete[] A12;delete[] A21;delete[] A22;delete[] B11;delete[] B12;delete[] B21;delete[] B22;delete[] C11;delete[] C12;delete[] C21;delete[] C22;delete[] M1;delete[] M2;delete[] M3;delete[] M4;delete[] M5;delete[] M6;delete[] M7;delete[] AResult;delete[] BResult;}}intmain(){
Strassen_class<int> stra;//定义Strassen_class类对象int MatrixSize =0;int** MatrixA;//存放矩阵Aint** MatrixB;//存放矩阵Bint** MatrixC;//存放结果矩阵
cout <<"\n请输入矩阵大小(必须是2的幂指数值(例如:32,64,512,..): ";
cin >> MatrixSize;int N = MatrixSize;//for readiblity.//申请内存
MatrixA =newint*[MatrixSize];
MatrixB =newint*[MatrixSize];
MatrixC =newint*[MatrixSize];//申请空间for(int i =0; i < MatrixSize; i++){
MatrixA[i]=newint[MatrixSize];
MatrixB[i]=newint[MatrixSize];
MatrixC[i]=newint[MatrixSize];}
stra.FillMatrix(MatrixA, MatrixB, MatrixSize);//矩阵赋值
stra.Strassen(N, MatrixA, MatrixB, MatrixC);//strassen矩阵相乘算法
cout <<"\n矩阵运算结果... \n";
stra.PrintMatrix(MatrixC, MatrixSize);return0;}