定义一个二维方阵类 matrix
通过重载二元运算符“+”、“-”、“*”和一元运算符“~”, 来实现矩阵加、矩阵减、矩阵乘以及矩阵转置。
matrix类的构造、拷贝构造及析构
1.由于矩阵的行与列都是未知的,首先需要通过动态分配内存实现创建任意大小的矩阵,由于类中默认的构造函数无法满足我们的需求,因此首先应该改写构造函数
matrix(int a, int b) {
r = a;
c = b;
mem = new int* [a];
for (int i = 0; i < a; i++) {
mem[i] = new int[b];
}
};
2.类中构造函数出现动态分配内存,就要考虑深拷贝与浅拷贝的问题,由于下文实现矩阵运算功能的函数中某些函数在传参时需要拷贝,因此必须通过深拷贝进行,否则会引发一系列问题,应该改写拷贝构造函数
matrix(const matrix& p) {
r = p.r;
c = p.c;
mem = new int* [r];
for (int i = 0; i < r; i++) {
mem[i] = new int[c];
}
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++) {
mem[i][j] = p.mem[i][j];
}
}
3.同样的,析构函数应释放动态分配的内存,这里也需要改写
~matrix() {
for (int i = 0; i < r; i++) {
delete[]mem[i];
}
delete[]mem;
};
matrix的运算符重载
1.矩阵加法(减法)
首先应判断两个矩阵的行宽和列宽是否一致,满足一致后开始加法运算
其次,传参时用const matrix& m
是为了防止引用造成原内容被修改(引用传参前能加const尽量加)
注意,这里便体现了改写拷贝构造函数的用处,return tmp;
发生了一次拷贝构造,如果没有改写,会引发浅拷贝的问题。
matrix operator+ (const matrix& m) {
//应判断两个矩阵的行宽和列宽是否一致
if (r != m.r || c != m.c) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
matrix tmp(r, c);
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
tmp.mem[i][j] = mem[i][j] + m.mem[i][j];
return tmp;
}
}
易错点:如下代码看似没有采用拷贝构造,但是违背了加法运算的本质(*this的值被修改)
matrix operator+ (matrix& m)//矩阵加
{
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
mem[i][j] = mem[i][j] + m.mem[i][j];
return *this;
}
减法原理与加法类似,不做赘述
2.矩阵乘法
matrix operator* (const matrix& m)//矩阵乘
{
if (c != m.r) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
matrix tmp(r, m.c);
for (int i = 0; i < r; i++)
for (int j = 0; j < m.c; j++)
tmp.mem[i][j] = 0;
for (int i = 0; i < tmp.r; i++)
for (int j = 0; j < tmp.c; j++)
for (int k = 0; k < c; k++)
tmp.mem[i][j] += (mem[i][k] * m.mem[k][j]);
return tmp;
}
}
3.矩阵转置
matrix operator~ ()//矩阵转置
{
matrix tmp(c, r);
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
tmp.mem[j][i] = mem[i][j];
return tmp;
}
4.=运算符重载
易错点matrix operator=(const matrix& m)
:如果返回值不是引用,当进行(a=b)=c时,a=b返回的不是a本身,而是一个临时变量,那么(a=b)=c相当于c的值最后没有赋值给a
matrix & operator=(const matrix& m)
{
if (c != m.r) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
mem[i][j] = m.mem[i][j];
return *this;
}
}
5.输出函数(用于打印结果)
friend void display(matrix m)//输出矩阵
{
for (int i = 0; i < m.r; i++) {
for (int j = 0; j < m.c; j++) {
cout << m.mem[i][j] << ' ';
}
cout << endl;
}
cout << "====================================================================================" << endl;
}
完整代码实现:
#include <iostream>
using namespace std;
class matrix {
public:
int r, c;
int** mem;
matrix(int a, int b) {
r = a;
c = b;
mem = new int* [a];
for (int i = 0; i < a; i++) {
mem[i] = new int[b];
}
};
matrix(const matrix& p) {
r = p.r;
c = p.c;
mem = new int* [r];
for (int i = 0; i < r; i++) {
mem[i] = new int[c];
}
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++) {
mem[i][j] = p.mem[i][j];
}
}
~matrix() {
for (int i = 0; i < r; i++) {
delete[]mem[i];
}
delete[]mem;
};
matrix operator+ (const matrix& m) {
if (r != m.r || c != m.c) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
matrix tmp(r, c);
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
tmp.mem[i][j] = mem[i][j] + m.mem[i][j];
return tmp;
}
}
matrix operator- (const matrix& m) {
if (r != m.r || c != m.c) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
matrix tmp(r, c);
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
tmp.mem[i][j] = mem[i][j] - m.mem[i][j];
return tmp;
}
}
matrix operator* (const matrix& m)//矩阵乘
{
if (c != m.r) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
matrix tmp(r, m.c);
for (int i = 0; i < r; i++)
for (int j = 0; j < m.c; j++)
tmp.mem[i][j] = 0;
for (int i = 0; i < tmp.r; i++)
for (int j = 0; j < tmp.c; j++)
for (int k = 0; k < c; k++)
tmp.mem[i][j] += (mem[i][k] * m.mem[k][j]);
return tmp;
}
}
matrix operator~ ()//矩阵转置
{
matrix tmp(c, r);
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
tmp.mem[j][i] = mem[i][j];
return tmp;
}
matrix & operator=(const matrix& m)
{
if (c != m.r) {
cout << "error";
matrix tmp(r, c);
tmp.mem = NULL;
return tmp;
}
else {
for (int i = 0; i < r; i++)
for (int j = 0; j < c; j++)
mem[i][j] = m.mem[i][j];
return *this;
}
}
friend void display(matrix m)//输出矩阵
{
for (int i = 0; i < m.r; i++) {
for (int j = 0; j < m.c; j++) {
cout << m.mem[i][j] << ' ';
}
cout << endl;
}
cout << "====================================================================================" << endl;
}
};
主函数测试结果:
int main() {
matrix p1(2, 3), p2(2, 3), p3(3, 2);
int num1 = 0, num2 = 1, num3 = 2;
for (int i = 0; i < p1.r; i++)
for (int j = 0; j < p1.c; j++) {
p1.mem[i][j] = num1;
num1++;
}
for (int i = 0; i < p2.r; i++)
for (int j = 0; j < p2.c; j++) {
p2.mem[i][j] = num2;
num2++;
}
for (int i = 0; i < p3.r; i++)
for (int j = 0; j < p3.c; j++) {
p3.mem[i][j] = num3;
num3++;
}
matrix p11 = p1 + p2;
cout << "p1" << endl;
display(p1);
cout << "p2" << endl;
display(p2);
cout << "p3" << endl;
display(p3);
cout << "p11" << endl;
display(p11);
cout << "p1+p2" << endl;
display(p1 + p2);
system("pause");
}