矩阵乘法

#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
using namespace std;

// Deletes the object `a` points to and resets `a` to nullptr.
// Wrapped in do { ... } while (0) so the macro expands to exactly one
// statement and stays correct inside an unbraced if/else; the argument is
// parenthesised so that it is evaluated as a single expression.
// Use it with a trailing semicolon: SAFEDELETE(ptr);
#define SAFEDELETE(a) do { delete (a); (a) = nullptr; } while (0)

// Integer matrix stored as an array of row pointers.
//
// A Matrix is either an owner (boSub == false): it allocated every row and
// frees them in its destructor; or a view (boSub == true) produced by
// GetHalfSubMatrix(): it owns only its table of row pointers, while the
// pointed-to elements belong to the parent matrix. A view must not be used
// after its parent has been destroyed.
//
// Copying (construction or assignment) always produces an independent,
// owning deep copy, so two Matrix objects never free the same storage.
struct Matrix
{
	// Creates an empty matrix: no storage, 0 x 0.
	Matrix() : row(0), col(0), data(nullptr), boSub(false) {}

	// Creates a zero-filled _row x _col matrix.
	Matrix(int _row, int _col) : row(0), col(0), data(nullptr), boSub(false)
	{
		SetData(_row, _col);
	}

	// Deep copy: the new matrix owns its own rows even if `other` is a view.
	Matrix(const Matrix &other) : row(0), col(0), data(nullptr), boSub(false)
	{
		CopyFrom(other);
	}

	// Deep-copy assignment. The copy is built first so that *this is left
	// untouched if an allocation fails.
	Matrix &operator=(const Matrix &other)
	{
		if (this != &other)
		{
			Matrix copy(other);
			Release();
			row = copy.row;
			col = copy.col;
			data = copy.data;
			boSub = copy.boSub;
			copy.data = nullptr; // storage now belongs to *this
		}
		return *this;
	}

	~Matrix()
	{
		Release();
	}

	// Allocates zero-filled _row x _col storage, but only when the matrix has
	// no storage yet. If storage already exists it is deliberately left
	// untouched and true is returned: the multiplication routines rely on
	// this to write into pre-sized results and into views.
	// Returns false (and changes nothing) when a dimension is negative.
	bool SetData(int _row, int _col)
	{
		if (data != nullptr)
		{
			return true;
		}
		if (_row < 0 || _col < 0)
		{
			return false;
		}

		row = _row;
		col = _col;
		// Freshly allocated storage is always owned, even if this object was
		// a (released) view before; otherwise it would never be freed.
		boSub = false;
		data = new int*[row];
		for (int i = 0; i < row; i++)
		{
			data[i] = new int[col](); // value-initialised: all zeros
		}
		return true;
	}

	// Frees rows and row table of an owning matrix. Does nothing for a view
	// (use SubRelease for those). row/col are left unchanged. Always true.
	bool DeleteData()
	{
		if (data != nullptr && boSub == false)
		{
			for (int i = 0; i < row; i++)
			{
				delete[] data[i];
				data[i] = nullptr;
			}
			delete[] data;
			data = nullptr;
		}
		return true;
	}

	// Prints the matrix row by row, each element followed by a space and each
	// row followed by a newline. Prints nothing when there is no storage.
	friend std::ostream &operator<<(std::ostream &os, const Matrix &m)
	{
		if (m.data == nullptr) return os;

		for (int i = 0; i < m.row; i++)
		{
			for (int j = 0; j < m.col; j++)
			{
				os << m.data[i][j] << " ";
			}
			os << std::endl;
		}
		return os;
	}

	// Pointer overload: same output as the reference overload; a null
	// pointer prints nothing.
	friend std::ostream &operator<<(std::ostream &os, const Matrix *m)
	{
		if (m == nullptr) return os;
		return os << *m;
	}

	// Fills every element with a pseudo-random value in [0, 9].
	// Returns false when the matrix has no storage.
	bool RandData()
	{
		if (data == nullptr) return false;

		for (int i = 0; i < row; i++)
		{
			for (int j = 0; j < col; j++)
			{
				data[i][j] = std::rand() % 10;
			}
		}
		return true;
	}

	// Returns a heap-allocated view of one quadrant of this matrix.
	//   type 11: top-left      type 12: top-right
	//   type 21: bottom-left   type 22: bottom-right
	// The view has (row / 2) x (col / 2) elements; with odd dimensions the
	// last row/column of the parent is not covered by any quadrant.
	// Writes through the view modify this matrix.
	//
	// The caller owns the returned object and must delete it (calling
	// SubRelease() first is allowed but no longer required).
	//
	// Special cases:
	//  - Parent without storage or without rows: the view has no storage.
	//  - Parent with a single row (or an unknown `type`): the view is 0 x 0,
	//    but data[0] still points at the parent's first row so that callers
	//    can write a 1 x 1 result through it.
	Matrix *GetHalfSubMatrix(int type)
	{
		Matrix *m = new Matrix;
		m->boSub = true;
		if (data == nullptr || row <= 0)
		{
			return m;
		}

		const int halfRow = row / 2;
		const int halfCol = col / 2;
		int rowOffset = 0;
		int colOffset = 0;
		bool known = true;
		switch (type)
		{
		case 11: // top-left
			break;
		case 12: // top-right
			colOffset = halfCol;
			break;
		case 21: // bottom-left
			rowOffset = halfRow;
			break;
		case 22: // bottom-right
			rowOffset = halfRow;
			colOffset = halfCol;
			break;
		default:
			known = false;
			break;
		}

		// Note: assigning m->data = this->data would alias the parent's row
		// table, so the view gets its own table of row pointers instead.
		m->data = new int*[row]();
		m->data[0] = data[0];
		if (known)
		{
			m->row = halfRow;
			m->col = halfCol;
			// Iterate over the quadrant's ROW count for every type; bounding
			// the loop by the column count breaks non-square matrices.
			for (int i = 0; i < halfRow; i++)
			{
				m->data[i] = data[rowOffset + i] + colOffset;
			}
		}
		return m;
	}

	// Releases the storage this object is responsible for. For a view that
	// is only the row-pointer table (the elements belong to the parent); for
	// an owning matrix it behaves like DeleteData(). Always returns true.
	bool SubRelease()
	{
		if (boSub == false)
		{
			return DeleteData();
		}
		if (data != nullptr)
		{
			delete[] data;
			data = nullptr;
		}
		return true;
	}

	int row;     // number of rows
	int col;     // number of columns
	int **data;  // row pointers; nullptr when there is no storage
	bool boSub;  // true when this object is a view into another matrix

private:
	// Frees whatever this object owns, for both owners and views.
	void Release()
	{
		if (boSub)
		{
			SubRelease();
		}
		else
		{
			DeleteData();
		}
	}

	// Makes *this an owning deep copy of `other`. Expects *this to hold no
	// storage on entry.
	void CopyFrom(const Matrix &other)
	{
		row = other.row;
		col = other.col;
		boSub = false;
		data = nullptr;
		if (other.data == nullptr)
		{
			return;
		}
		if (row < 0) row = 0;
		if (col < 0) col = 0;

		data = new int*[row];
		for (int i = 0; i < row; i++)
		{
			data[i] = new int[col];
			for (int j = 0; j < col; j++)
			{
				data[i][j] = other.data[i][j];
			}
		}
	}
};

// Matrix multiplication, brute force (O(n^3) for n x n operands).
//
// Computes a * b into nResult when a->col == b->row. Otherwise, if
// a->row == b->col, the operands are swapped and b * a is computed instead.
// If neither holds, "MatrixMul Can`t Run" is printed and nResult is left
// untouched.
//
// nResult is sized through SetData(), which only allocates when nResult has
// no storage yet; a result that already has storage (including a view) is
// written in place and must be large enough for the product.
//
// Each output row is accumulated in a temporary buffer and copied to nResult
// only once the whole row is finished.
void MatrixMul(Matrix *a, Matrix *b, Matrix *nResult)
{
	if (a == nullptr || b == nullptr || nResult == nullptr)
	{
		std::cout << "MatrixMul Can`t Run" << std::endl;
		return;
	}

	Matrix *first = nullptr;
	Matrix *second = nullptr;
	if (a->col == b->row)
	{
		first = a;
		second = b;
	}
	else if (a->row == b->col)
	{
		first = b;
		second = a;
	}

	if (first == nullptr || second == nullptr)
	{
		std::cout << "MatrixMul Can`t Run" << std::endl;
		return;
	}

	nResult->SetData(first->row, second->col);

	// Nothing below may index a matrix that has no storage.
	if (first->data == nullptr || second->data == nullptr || nResult->data == nullptr)
	{
		std::cout << "MatrixMul Can`t Run" << std::endl;
		return;
	}

	const int outRows = first->row;
	const int outCols = second->col;
	const int inner = first->col;

	std::vector<int> rowBuf;
	if (outCols > 0)
	{
		rowBuf.resize(outCols);
	}

	for (int i = 0; i < outRows; i++)
	{
		for (int j = 0; j < outCols; j++)
		{
			int sum = 0;
			for (int k = 0; k < inner; k++)
			{
				sum += first->data[i][k] * second->data[k][j];
			}
			rowBuf[j] = sum;
		}
		for (int j = 0; j < outCols; j++)
		{
			nResult->data[i][j] = rowBuf[j];
		}
	}
}


// Element-wise addition: c = a + b over a's shape (a->row x a->col).
// c may alias a or b. Does nothing when any pointer is null, when any of the
// matrices has no storage, or when b or c is smaller than a (which would
// otherwise read or write out of bounds).
void MatrixAdd(Matrix *a, Matrix *b, Matrix *c)
{
	if (a == nullptr || b == nullptr || c == nullptr) return;
	if (a->data == nullptr || b->data == nullptr || c->data == nullptr) return;
	if (b->row < a->row || b->col < a->col) return;
	if (c->row < a->row || c->col < a->col) return;

	for (int i = 0; i < a->row; i++)
	{
		for (int j = 0; j < a->col; j++)
		{
			c->data[i][j] = a->data[i][j] + b->data[i][j];
		}
	}
}
// Element-wise subtraction: c = a - b over a's shape (a->row x a->col).
// c may alias a or b. Does nothing when any pointer is null, when any of the
// matrices has no storage, or when b or c is smaller than a (which would
// otherwise read or write out of bounds).
void MatrixSub(Matrix *a, Matrix *b, Matrix *c)
{
	if (a == nullptr || b == nullptr || c == nullptr) return;
	if (a->data == nullptr || b->data == nullptr || c->data == nullptr) return;
	if (b->row < a->row || b->col < a->col) return;
	if (c->row < a->row || c->col < a->col) return;

	for (int i = 0; i < a->row; i++)
	{
		for (int j = 0; j < a->col; j++)
		{
			c->data[i][j] = a->data[i][j] - b->data[i][j];
		}
	}
}
// Checks whether a and b can be multiplied and sizes nResult for the product.
//
// If a->col == b->row, nResult is sized a->row x b->col (a * b). Otherwise,
// if a->row == b->col, it is sized b->row x a->col (b * a). Sizing goes
// through SetData(), so a result that already has storage is left as it is.
//
// Returns true when one of the two orders is possible. Otherwise prints
// "MatrixMul Can`t Run" and returns false; a null argument is treated the
// same way.
bool CheckMatrix(Matrix *a, Matrix *b, Matrix *nResult)
{
	const bool haveAll = (a != nullptr && b != nullptr && nResult != nullptr);

	if (haveAll && a->col == b->row)
	{
		nResult->SetData(a->row, b->col);
		return true;
	}
	if (haveAll && a->row == b->col)
	{
		nResult->SetData(b->row, a->col);
		return true;
	}

	std::cout << "MatrixMul Can`t Run" << std::endl;
	return false;
}

// Divide-and-conquer matrix multiplication (eight recursive products).
//
// Returns a newly allocated matrix holding the product; the caller owns it
// and must delete it. The returned pointer is never null: when the operands
// are null or incompatible, an empty matrix is returned.
//
// The operands are split into quadrants only when a->col == b->row and
// a->row, a->col and b->col are all even (and therefore >= 2). In every other
// case (1 x 1 blocks, odd sizes, swapped operand order) the product is
// computed directly with MatrixMul, because halving would silently drop the
// last row/column.
//
// While recursing, the quadrants of a and b and the assembled result are
// printed to std::cout.
Matrix * MatrixMulMarge(Matrix *a, Matrix *b)
{
	Matrix *nResult = new Matrix();
	if (a == nullptr || b == nullptr || !CheckMatrix(a, b, nResult))
	{
		return nResult;
	}

	const bool canSplit =
		a->col == b->row &&
		a->row >= 2 && a->col >= 2 && b->col >= 2 &&
		a->row % 2 == 0 && a->col % 2 == 0 && b->col % 2 == 0;
	if (!canSplit)
	{
		// nResult already has storage, so MatrixMul fills it in place.
		MatrixMul(a, b, nResult);
		return nResult;
	}

	// dst = l1 * r1 + l2 * r2, with both partial products freed afterwards.
	auto mulAdd = [](Matrix *l1, Matrix *r1, Matrix *l2, Matrix *r2, Matrix *dst)
	{
		Matrix *left = MatrixMulMarge(l1, r1);
		Matrix *right = MatrixMulMarge(l2, r2);
		MatrixAdd(left, right, dst);
		delete left;
		delete right;
	};
	// Frees a quadrant view: first its row table, then the object itself.
	auto releaseView = [](Matrix *view)
	{
		view->SubRelease();
		delete view;
	};

	Matrix *c11 = nResult->GetHalfSubMatrix(11);
	Matrix *c12 = nResult->GetHalfSubMatrix(12);
	Matrix *c21 = nResult->GetHalfSubMatrix(21);
	Matrix *c22 = nResult->GetHalfSubMatrix(22);

	Matrix *a11 = a->GetHalfSubMatrix(11);
	Matrix *a12 = a->GetHalfSubMatrix(12);
	Matrix *a21 = a->GetHalfSubMatrix(21);
	Matrix *a22 = a->GetHalfSubMatrix(22);
	std::cout << "a = " << std::endl << a11 << " " << a12 << " " << a21 << " " << a22 << std::endl;

	Matrix *b11 = b->GetHalfSubMatrix(11);
	Matrix *b12 = b->GetHalfSubMatrix(12);
	Matrix *b21 = b->GetHalfSubMatrix(21);
	Matrix *b22 = b->GetHalfSubMatrix(22);
	std::cout << "b = " << std::endl << b11 << " " << b12 << " " << b21 << " " << b22 << std::endl;

	mulAdd(a11, b11, a12, b21, c11); // c11 = a11*b11 + a12*b21
	mulAdd(a11, b12, a12, b22, c12); // c12 = a11*b12 + a12*b22
	mulAdd(a21, b11, a22, b21, c21); // c21 = a21*b11 + a22*b21
	mulAdd(a21, b12, a22, b22, c22); // c22 = a21*b12 + a22*b22

	std::cout << "c = " << nResult << std::endl;

	releaseView(a11);
	releaseView(a12);
	releaseView(a21);
	releaseView(a22);

	releaseView(b11);
	releaseView(b12);
	releaseView(b21);
	releaseView(b22);

	releaseView(c11);
	releaseView(c12);
	releaseView(c21);
	releaseView(c22);

	return nResult;
}

// Divide-and-conquer multiplication.
//
// Multiplies a and b with MatrixMulMarge, prints the product to std::cout
// and stores it in c. c is sized by CheckMatrix (only when it has no storage
// yet); if c already has storage that is too small for the product, the
// product is printed but c is left unchanged. Does nothing when a pointer is
// null or when the operands cannot be multiplied.
void MatrixMul2(Matrix *a, Matrix *b, Matrix *c)
{
	if (a == nullptr || b == nullptr || c == nullptr)
		return;
	if (CheckMatrix(a, b, c) == false)
		return;

	Matrix *nResult = MatrixMulMarge(a, b);
	std::cout << nResult << std::endl;

	// Hand the product back through the output parameter.
	const bool fits =
		nResult->data != nullptr && c->data != nullptr &&
		c->row >= nResult->row && c->col >= nResult->col;
	if (fits)
	{
		for (int i = 0; i < nResult->row; i++)
		{
			for (int j = 0; j < nResult->col; j++)
			{
				c->data[i][j] = nResult->data[i][j];
			}
		}
	}

	delete nResult;
	nResult = nullptr;
}

// Strassen matrix multiplication (seven recursive products per level).
//
// Returns a newly allocated matrix holding the product; the caller owns it
// and must delete it. The returned pointer is never null: when the operands
// are null or incompatible, an empty matrix is returned.
//
// The operands are split into quadrants only when a->col == b->row and
// a->row, a->col and b->col are all even (and therefore >= 2). In every other
// case (1 x 1 blocks, odd sizes, swapped operand order) the product is
// computed directly with MatrixMul, because halving would silently drop the
// last row/column.
//
// While recursing, the quadrants of a and b are printed to std::cout.
Matrix * MatrixMulMarge_ex(Matrix *a, Matrix *b)
{
	Matrix *nResult = new Matrix();
	if (a == nullptr || b == nullptr || !CheckMatrix(a, b, nResult))
	{
		return nResult;
	}

	const bool canSplit =
		a->col == b->row &&
		a->row >= 2 && a->col >= 2 && b->col >= 2 &&
		a->row % 2 == 0 && a->col % 2 == 0 && b->col % 2 == 0;
	if (!canSplit)
	{
		// nResult already has storage, so MatrixMul fills it in place.
		MatrixMul(a, b, nResult);
		return nResult;
	}

	// Frees a quadrant view: first its row table, then the object itself.
	auto releaseView = [](Matrix *view)
	{
		view->SubRelease();
		delete view;
	};

	Matrix *c11 = nResult->GetHalfSubMatrix(11);
	Matrix *c12 = nResult->GetHalfSubMatrix(12);
	Matrix *c21 = nResult->GetHalfSubMatrix(21);
	Matrix *c22 = nResult->GetHalfSubMatrix(22);

	Matrix *a11 = a->GetHalfSubMatrix(11);
	Matrix *a12 = a->GetHalfSubMatrix(12);
	Matrix *a21 = a->GetHalfSubMatrix(21);
	Matrix *a22 = a->GetHalfSubMatrix(22);
	std::cout << "a = " << std::endl << a11 << " " << a12 << " " << a21 << " " << a22 << std::endl;

	Matrix *b11 = b->GetHalfSubMatrix(11);
	Matrix *b12 = b->GetHalfSubMatrix(12);
	Matrix *b21 = b->GetHalfSubMatrix(21);
	Matrix *b22 = b->GetHalfSubMatrix(22);
	std::cout << "b = " << std::endl << b11 << " " << b12 << " " << b21 << " " << b22 << std::endl;

	// The ten sums/differences share the shape of their operands, so they are
	// sized from the quadrant itself rather than through CheckMatrix (which
	// validates multiplication shapes and fails for non-square quadrants).
	// They live on the stack and free themselves on return.
	Matrix S1(b12->row, b12->col);   // S1  = b12 - b22
	MatrixSub(b12, b22, &S1);
	Matrix S2(a11->row, a11->col);   // S2  = a11 + a12
	MatrixAdd(a11, a12, &S2);
	Matrix S3(a21->row, a21->col);   // S3  = a21 + a22
	MatrixAdd(a21, a22, &S3);
	Matrix S4(b21->row, b21->col);   // S4  = b21 - b11
	MatrixSub(b21, b11, &S4);
	Matrix S5(a11->row, a11->col);   // S5  = a11 + a22
	MatrixAdd(a11, a22, &S5);
	Matrix S6(b11->row, b11->col);   // S6  = b11 + b22
	MatrixAdd(b11, b22, &S6);
	Matrix S7(a12->row, a12->col);   // S7  = a12 - a22
	MatrixSub(a12, a22, &S7);
	Matrix S8(b21->row, b21->col);   // S8  = b21 + b22
	MatrixAdd(b21, b22, &S8);
	Matrix S9(a11->row, a11->col);   // S9  = a11 - a21
	MatrixSub(a11, a21, &S9);
	Matrix S10(b11->row, b11->col);  // S10 = b11 + b12
	MatrixAdd(b11, b12, &S10);

	// Seven recursive products.
	Matrix *P1 = MatrixMulMarge_ex(a11, &S1);
	Matrix *P2 = MatrixMulMarge_ex(&S2, b22);
	Matrix *P3 = MatrixMulMarge_ex(&S3, b11);
	Matrix *P4 = MatrixMulMarge_ex(a22, &S4);
	Matrix *P5 = MatrixMulMarge_ex(&S5, &S6);
	Matrix *P6 = MatrixMulMarge_ex(&S7, &S8);
	Matrix *P7 = MatrixMulMarge_ex(&S9, &S10);

	// c11 = P5 + P4 - P2 + P6
	MatrixAdd(P5, P4, c11);
	MatrixSub(c11, P2, c11);
	MatrixAdd(c11, P6, c11);

	// c12 = P1 + P2
	MatrixAdd(P1, P2, c12);

	// c21 = P3 + P4
	MatrixAdd(P3, P4, c21);

	// c22 = P5 + P1 - P3 - P7
	MatrixAdd(P5, P1, c22);
	MatrixSub(c22, P3, c22);
	MatrixSub(c22, P7, c22);

	delete P1;
	delete P2;
	delete P3;
	delete P4;
	delete P5;
	delete P6;
	delete P7;

	releaseView(a11);
	releaseView(a12);
	releaseView(a21);
	releaseView(a22);

	releaseView(b11);
	releaseView(b12);
	releaseView(b21);
	releaseView(b22);

	releaseView(c11);
	releaseView(c12);
	releaseView(c21);
	releaseView(c22);

	return nResult;
}
// Strassen algorithm: multiplies a and b with MatrixMulMarge_ex, prints the
// product to std::cout followed by a newline, then frees the product.
void MatrixMul3(Matrix *a, Matrix *b)
{
	Matrix *product = MatrixMulMarge_ex(a, b);

	std::cout << product;
	std::cout << std::endl;

	delete product;
}
// Demo driver: builds two random 2 x 2 matrices, prints them, and prints
// their product as computed by the Strassen implementation (MatrixMul3).
int main()
{
	// Seed from the clock so every run uses different values.
	srand(time(NULL));

	Matrix a(2, 2);
	a.RandData();

	Matrix b(2, 2);
	b.RandData();

	// Only needed by the alternative (disabled) entry points listed below.
	Matrix c;

	std::cout << "a = " << std::endl;
	std::cout << a;
	std::cout << "b = " << std::endl;
	std::cout << b;

	MatrixMul3(&a, &b);

	// Alternatives kept for reference:
	//   MatrixMul2(&a, &b, &c);   divide and conquer
	//   MatrixMul(&a, &b, &c);    brute force
	//   cout << "c = " << endl << c;

	return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值