Strassen矩阵乘法,通过把乘法次数从8次减少到了7次,从而达到了算法的优化(当然,这只是在2*2的情况下哈),递归方程为T(n) = 7*T(n/2) + O(n^2) (n > 2),优于传统的划分T(n) = 8*T(n/2) + O(n^2) (n > 2)。
调试了n久才出来的程序,贴之如下:
//Matrix.h
#ifndef MATRIX_H
#define MATRIX_H
class Matrix
{
public:
int ** M;
int n;//矩阵的维数是n*n的。
public:
Matrix(int dem);
Matrix(Matrix &Copy);
~Matrix();
void ShowNumber();//显示矩阵的数据内容
bool DivideIntoFourParts(Matrix &A11,Matrix &A12,Matrix &A21,Matrix &A22);//划分成四个部分
bool UnionTogether(Matrix &A11,Matrix &A12,Matrix &A21,Matrix &A22);//用四个部分合成
Matrix & operator = (Matrix &A);//等号重载
//friend Matrix & operator = (Matrix &A);
friend Matrix Add(Matrix &A,Matrix &B);//矩阵的加法,可以继续完善成加法的重载,姑且用Add替代一下
friend Matrix Sub(Matrix &A,Matrix &B);//减法
friend Matrix Multiply(Matrix &A,Matrix &B);//矩阵乘法
};
#endif
通过一个矩阵类来封装所有的加减乘操作,当然了,还有就是等号的重载,调试过程中遇到了一个问题:那就是类中成员函数的返回值类型问题,当返回的是Matrix的引用和Matrix对象的时候,程序中出现了不同的问题,于是,这个问题有待继续解决和了解!
然后就是Matrix类的实现部分和几个友元函数的实现:
//Matrix.cpp
#include"Matrix.h"
#include
using namespace std;
Matrix::Matrix(int dem)
{
n = dem;
M = new int*[n+1];
for(int i = 1; i <= n; i ++)
M[i] = new int[n+1];
}
Matrix::Matrix(Matrix &Copy)
{
n = Copy.n;
M = new int*[n+1];
for(int i = 1; i <= n; i ++){
M[i] = new int[n+1];
for(int j = 1; j <= n; j ++)
M[i][j] = Copy.M[i][j];
}
}
Matrix::~Matrix()
{
for(int i = 1; i <= n; i ++)
delete M[i];
delete M;
}
void Matrix::ShowNumber()
{
for(int i = 1; i <= n; i ++){
for(int j = 1; j <= n; j ++)
cout << M[i][j] << " ";
cout <<endl;
}
cout <<endl;
}
bool Matrix::DivideIntoFourParts(
Matrix &A11,Matrix &A12,Matrix &A21,Matrix &A22
){
if(n%2 != 0)return false;
int m = n/2;
int i,j;
for(i = 1; i <= m; i ++)
for(j = 1; j <= m; j ++)
A11.M[i][j] = M[i][j];
for(i = 1; i <= m; i ++)
for(j = m+1; j <= n;j ++)
A12.M[i][j-m] = M[i][j];
for(i = m+1; i <= n; i ++)
for(j = 1; j <= m; j ++)
A21.M[i-m][j] = M[i][j];
for(i = m+1; i <= n; i ++)
for(j = m+1; j <= n; j ++)
A22.M[i-m][j-m] = M[i][j];
return true;
}
bool Matrix::UnionTogether(
Matrix &A11,Matrix &A12,Matrix &A21,Matrix &A22
){
if(n%2 != 0)return false;
int m = n/2;
int i,j;
for(i = 1; i <= m; i ++)
for(j = 1; j <= m; j ++)
M[i][j] = A11.M[i][j];
for(i = 1; i <= m; i ++)
for(j = m+1; j <= n;j ++)
M[i][j] = A12.M[i][j-m];
for(i = m+1; i <= n; i ++)
for(j = 1; j <= m; j ++)
M[i][j] = A21.M[i-m][j];
for(i = m+1; i <= n; i ++)
for(j = m+1; j <= n; j ++)
M[i][j] = A22.M[i-m][j-m];
return true;
}
Matrix Add(Matrix &A,Matrix &B){
if(A.n != B.n){
cout << "矩阵维数不同,不能相加!" <<endl;
exit(0);
}
else{
Matrix res(A.n);
for(int i = 1; i <= A.n; i ++)
for(int j = 1; j <= A.n; j++)
res.M[i][j] = A.M[i][j] + B.M[i][j];
return res;
}
}
Matrix Sub(Matrix &A,Matrix &B){
if(A.n != B.n){
cout << "矩阵维数不同,不能相减!" <<endl;
exit(0);
}
else{
Matrix res(A.n);
for(int i = 1; i <= A.n; i ++)
for(int j = 1; j <= A.n; j++)
res.M[i][j] = A.M[i][j] - B.M[i][j];
return res;
}
}
Matrix Multiply(Matrix &A,Matrix &B){
if(A.n != B.n){
cout << "矩阵维数不同,不能相乘!" <<endl;
exit(0);
}
else{
Matrix res(A.n);
for(int i = 1; i <= A.n; i ++)
for(int j = 1; j <= A.n; j++){
res.M[i][j] = 0;
for(int k = 1; k <= A.n; k ++)
res.M[i][j] += A.M[i][k] *B.M[k][j];
}
return res;
}
}
Matrix & Matrix::operator = (Matrix &A){
for(int i = 1; i <= A.n; i ++)
for(int j = 1; j <= A.n; j ++)
M[i][j] = A.M[i][j];
//cout << "调用了=重载函数"<<endl;
return *this;
}
当然了,最重要的还是主函数里面的StrassenMatrixMultiply函数的递归调用,代码如下:
//main.cpp
#include
#include"Matrix.h"
using namespace std;
Matrix StrassenMatrixMultiply(Matrix & A,Matrix & B){
if(A.n != B.n){
cout << "矩阵规模不匹配!" <<endl;
exit(0);
}
else{
int n = A.n;
int m = n/2;
Matrix res(n);
if(n == 2){
res = Multiply(A,B);
return res;
}
else{//还需要继续划分
Matrix A11(m),A12(m),A21(m),A22(m),
B11(m),B12(m),B21(m),B22(m);
Matrix M1(m),M2(m),M3(m),
M4(m),M5(m),M6(m),M7(m);
Matrix C11(m),C12(m),C21(m),C22(m);
//对矩阵A、B进行划分操作
A.DivideIntoFourParts(A11,A12,A21,A22);
B.DivideIntoFourParts(B11,B12,B21,B22);
M1 = StrassenMatrixMultiply(A11,Sub(B12,B22));//递归调用
M2 = StrassenMatrixMultiply(Add(A11,A12),B22);
M3 = StrassenMatrixMultiply(Add(A21,A22),B11);
M4 = StrassenMatrixMultiply(A22,Sub(B21,B11));
M5 = StrassenMatrixMultiply(Add(A11,A22),Add(B11,B22));
M6 = StrassenMatrixMultiply(Sub(A12,A22),Add(B21,B22));
M7 = StrassenMatrixMultiply(Sub(A11,A21),Add(B11,B12));
//M7.ShowNumber();
C11 = Add(Sub(Add(M5,M4),M2),M6);
C12 = Add(M1,M2);
C21 = Add(M3,M4);
C22 = Sub(Sub(Add(M5,M1),M3),M7);
res.UnionTogether(C11,C12,C21,C22);
return res;
}
}
}
int main()
{
int n;//矩阵大小n*n
cout << "请输入矩阵规模n(假设n是2的幂):" <<endl;
cin >> n;
Matrix A(n),B(n),C(n);
int i,j,m = n/2;
//输入A、B
for(i = 1; i <= n; i ++)
for(j = 1;j <= n; j ++)
cin >> A.M[i][j];
for(i = 1; i <= n; i ++)
for(j = 1;j <= n; j ++)
cin >> B.M[i][j];
//显示
cout << endl << "显示输入的矩阵A、B数据:" <<endl;
A.ShowNumber();
B.ShowNumber();
cout << "****************************************"<<endl;
cout << "计算结果(A*B)如下:" <<endl<<endl;
StrassenMatrixMultiply(A,B).ShowNumber();
return 0;
}
好了,就这样吧!第一次贴...见谅!