Salmon_lee的博客

在学习中不断提升~

仿Matlab 矩阵类

代码说明:

①模仿Matlab的思想对矩阵进行操作

②利用(double*) data动态分配内存来保存矩阵的的元素


主要功能(持续更新):


运行实例:


Matrix.h(持续更新):

#include <iostream>
#include <iomanip>
#include <random>
using namespace std;

void swap( double& a, double& b )
{
	double tmp = a;
	a = b;
	b = tmp;
}

class Matrix {
private:
	int row = 0;
	int col = 0;
	double* data = nullptr;													//存储数据
	static int precision;													//设置小数位数

public:
	inline int getRow() const;												//行
	inline int getCol() const;												//列
	inline double& getValue( int i, int j );								//获取值(返回左值引用)
	inline void setValue( const int* a );									//利用数组设置元素
	inline void random( int low, int high );								//随机元素[low,high]
	static void setPrecision( int a );										//设置输出时的小数位数

	Matrix( const Matrix& a );												//拷贝构造函数
	Matrix( int r, int c );													//构造1(行,列)
	Matrix( int r, int c, const int* a );									//构造2(行,列,元素初始化)
	~Matrix();																//析构函数

	Matrix& operator = ( const Matrix& a );									//赋值运算符重载
	friend bool operator == ( const Matrix& a, const Matrix& b ); 			//判断相等

	friend istream& operator >> ( istream& input, Matrix& a );				//输入
	friend ostream& operator << ( ostream& output, const Matrix& a );		//输出

	friend Matrix operator + ( const Matrix& a, const Matrix& b );			// +
	friend Matrix operator - ( const Matrix& a, const Matrix& b );			// -
	friend Matrix operator * ( const Matrix& a, const Matrix& b );			// 矩阵乘法
	friend Matrix operator * ( const Matrix& a, double b );					// 数乘1
	friend Matrix operator * ( double b, const Matrix& a );					// 数乘2

	void swapRow( int i0, int i1 );											//交换两行
	void linearChange( int des, double times, int src );					//des += times*src
	int legalify( int col );												//使基准行合法

	double det();															//行列式(高斯消元法)

	Matrix operator ^ ( int b );											//矩阵快速幂(log b)
	Matrix triu();															//上三角矩阵(高斯消元法)
	Matrix inv();															//求逆(高斯-亚当思想)		【0.01秒内可求100阶】
	Matrix tran();															//转置
	static Matrix eye( int n );												//返回一个n阶单位阵
};

int Matrix::precision = 4;

int Matrix::getRow() const {
	return row;
}

int Matrix::getCol() const {
	return col;
}

double& Matrix::getValue( int i, int j ) {
	if (1 <= i && i <= row && 1 <= j && j <= col)
		return data[(i - 1) * col + j - 1];
	else
	{
		cerr << "引用非法数据" << endl;
		return data[0];
	}
}

Matrix::Matrix( const Matrix & a ) {
	row = a.row;
	col = a.col;
	data = new double[row * col];
	memcpy( data, a.data, sizeof( double ) * row * col );
}

Matrix::Matrix( int r, int c ) {
	row = r;
	col = c;
	data = new double[r * c];
	memset( data, 0, r * c * sizeof( double ) );
}

Matrix::Matrix( int r, int c, const int* a )
{
	row = r;
	col = c;
	setValue( a );
}

Matrix::~Matrix()
{
	if (data != nullptr)
		delete[] data;
}

Matrix& Matrix::operator = ( const Matrix & a )
{
	row = a.row;
	col = a.col;
	if (data != nullptr) delete[] data;
	data = new double[row*col];
	memcpy( data, a.data, sizeof( double ) * row * col );
	return *this;
}

void Matrix::setValue( const int* a ) {
	memcpy( data, a, sizeof( int ) * row * col );
}

inline void Matrix::random( int low, int high )
{
	default_random_engine e;
	uniform_real_distribution<double> u( low, high );
	for (int i = 0; i < row * col; ++i)
		data[i] = u( e );
}

inline void Matrix::setPrecision( int a )
{
	precision = a;
}

bool operator == ( const Matrix& a, const Matrix& b ) {
	if (a.row != b.row || a.col != b.col)
		return false;
	//判断是否相等
	for (int i = 0; i < a.row * a.col; ++i)
		if (a.data[i] != b.data[i])
			return false;

	return true;
}

istream& operator >> ( istream& input, Matrix& a )
{
	for (int i = 0; i < a.row * a.col; ++i)
		input >> a.data[i];
	return input;
}

ostream& operator << ( ostream& output, const Matrix& a )
{
	for (int i = 1; i <= a.row * a.col; ++i)
	{
		output << fixed << setprecision( Matrix::precision ) << (a.data[i - 1] == 0 ? 0.0 : a.data[i - 1]) << " ";
		if (i % a.col == 0) output << endl;
	}
	return output;
}

Matrix operator + ( const Matrix& a, const Matrix& b ) {
	if (a.row != b.row || a.col != b.col) {
		cerr << ("类型不同,无法相加") << endl;
		return Matrix( 0, 0 );
	}

	Matrix temp( a.col, a.row );
	for (int i = 0; i < a.row * a.col; ++i)
		temp.data[i] = a.data[i] + b.data[i];

	return temp;
}
Matrix operator - ( const Matrix& a, const Matrix& b ) {
	if (a.row != b.row || a.col != b.col) {
		cerr << ("类型不同,无法相减") << endl;
		return Matrix( 0, 0 );
	}

	Matrix temp( a );
	for (int i = 0; i < a.row * a.col; ++i)
		temp.data[i] -= b.data[i];

	return temp;
}
Matrix operator * ( const Matrix& a, const Matrix& b ) {
	if (a.col != b.row) {
		cerr << ("类型不符,无法相乘") << endl;
		return Matrix( 0, 0 );
	}
	Matrix temp( a.row, b.col );
	for (int i = 1; i <= a.row; ++i)
	{
		for (int j = 1; j <= b.col; ++j)
			for (int k = 1; k <= a.row; ++k)
				temp.data[(i - 1) * temp.col + j - 1] += a.data[(i - 1) * a.col + k - 1] * b.data[(k - 1) * b.col + j - 1];
	}
	return temp;
}

Matrix operator * ( const Matrix& a, double b )
{
	Matrix tmp( a );
	for (int i = 0; i < tmp.row * tmp.col; ++i)
		tmp.data[i] *= b;
	return tmp;
}

Matrix operator * ( double b, const Matrix& a )
{
	Matrix tmp( a );
	for (int i = 0; i < tmp.row * tmp.col; ++i)
		tmp.data[i] *= b;
	return tmp;
}

void Matrix::swapRow( int i0, int i1 )
{
	for (int j = 1; j <= col; ++j)
		swap( getValue( i0, j ), getValue( i1, j ) );
}

void Matrix::linearChange( int des, double times, int src )					//des += times src
{
	for (int j = 1; j <= col; ++j)
		getValue( des, j ) += times * getValue( src, j );
}

int Matrix::legalify( int col )
{
	if (getValue( col, col ) != 0) return 1;								//若基准行的基准数不为0,返回1

	for (int legal_row = col + 1; legal_row <= row; ++legal_row)			//若基准行的基准数为0,则向下寻找非0的行
	{
		if (getValue( legal_row, col ) != 0)
		{
			swapRow( legal_row, col );										//若找到,交换两行,并改变行列式符号
			return -1;
		}
	}
	return 0;																//若没找到,返回0
}

double Matrix::det()
{
	Matrix tmp( *this );
	if (tmp.row != tmp.col)
	{
		cout << "此矩阵无行列式" << endl;
		return -1.0;
	}
	double times, res = 1, flag;											//times为线性系数,flag为符号位
	for (int j = 1; j < col; ++j)
	{
		flag = tmp.legalify( j );
		if (flag == 0) return 0.0;											//若此列全为0,则直接进入下一列
		res *= flag;

		for (int i = j + 1; i <= row; ++i)
		{
			if (tmp.getValue( i, j ) == 0) continue;
			times = -tmp.getValue( i, j ) / tmp.getValue( j, j );
			tmp.linearChange( i, times, j );								//线性变换, 使目标行的目标值为0
		}
	}

	for (int i = 1; i <= row; ++i)
		res *= tmp.getValue( i, i );
	return res == 0 ? 0.0 : res;
}

Matrix Matrix::operator ^ ( int b )
{
	if (col != row)
	{
		cerr << "不是方阵,无法求幂";
		return Matrix( row, col );
	}

	if (b == 0) return Matrix::eye( row );
	if (b < 0) return inv() ^ (-b);

	Matrix base( *this );
	Matrix res( Matrix::eye( row ) );

	while (b)
	{
		if (b & 1) res = res * base;
		base = base * base;
		b >>= 1;
	}
	return res;

}

Matrix Matrix::triu()
{
	double times;
	Matrix tmp( *this );
	for (int j = 1; j < row; ++j)
	{
		tmp.legalify( j );													//若此列全为0,则直接进入下一列
		for (int i = j + 1; i <= row; ++i)
		{
			if (tmp.getValue( i, j ) == 0) continue;
			times = -tmp.getValue( i, j ) / tmp.getValue( j, j );
			tmp.linearChange( i, times, j );								//线性变换, 使目标行的目标值为0
		}
	}
	return tmp;
}

Matrix Matrix::inv()
{
	if (row != col || det() == 0)											//注:未考虑 广义逆矩阵
	{
		cerr << "无逆矩阵" << endl;
		return Matrix( 0, 0 );
	}

	//配置AE
	Matrix AE( row, 2 * col );
	for (int i = 1; i <= row; ++i)
		for (int j = 1; j <= col; ++j)
			AE.getValue( i, j ) = getValue( i, j );

	for (int i = 1; i <= row; ++i)
		AE.getValue( i, row + i ) = 1;

	AE = AE.triu();															//A化为上三角

	double times;
	for (int j = col; j > 1; --j)
		for (int i = j - 1; i >= 1; --i)
		{
			times = -AE.getValue( i, j ) / AE.getValue( j, j );				//A化为对角矩阵
			AE.linearChange( i, times, j );
		}

	for (int i = 1; i <= AE.row; ++i)										//对角矩阵化为E
	{
		times = AE.getValue( i, i );
		if (times == 1) continue;
		for (int j = 1; j <= AE.col; ++j)
			AE.getValue( i, j ) /= times;
	}

	Matrix inver( row, col );
	for (int i = 1; i <= row; ++i)
		for (int j = 1; j <= col; ++j)
			inver.getValue( i, j ) = AE.getValue( i, j + col );
	return inver;
}

Matrix Matrix::tran()
{
	Matrix t( col, row );
	for (int i = 1; i <= t.row; ++i)
		for (int j = 1; j <= t.row; ++j)
			t.getValue( i, j ) = getValue( j, i );

	return t;
}

Matrix Matrix::eye( int n )
{
	Matrix e( n, n );
	for (int i = 1; i <= n; ++i)
		e.getValue( i, i ) = 1;
	return e;
}

参考资料:

Matlab矩阵相关函数

临时变量作为引用传参

阅读更多
版权声明: https://blog.csdn.net/leelitian3/article/details/79948666
个人分类: 离散&线代
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

不良信息举报

仿Matlab 矩阵类

最多只允许输入30个字

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭