实现一个KMatrix矩阵容器类
作业描述
使用C++编写一个KMatrix容器类,内部可以存储一个二维矩阵数据,并满足以下要求:
1、使用所学的标准容器来组织和存储矩阵数据;
2、KMatrix可以存储int/double等常规数值类型,同一个矩阵内部存储的数据类型是一致的;
3、实现KMatrix的初始化函数(KMatrix::init(row_count, col_count))(初始数据为0);
4、实现KMatrix获取行列数的函数(int KMatrix::getRows() const、int KMatrix::getCols() const);
5、实现KMatrix的数据修改与获取函数(KMatrix::setData(row, col, value)、Value KMatrix::getData(row, col) const);
6、实现KMatrix删除行列的函数(KMatrix::erase_row(row)、KMatrix::erase_col(col));
7、实现KMatrix的加(+)、减(-)、叉乘(*) 运算, 使用运算符重载实现;
8、实现KMatrix的转置(KMatrix KMatrix::transpose() const) (交换行列);
9、实现KMatrix的控制台打印输出(KMatrix::print() const) (需要体现矩阵的基本结构);
加分项:
- 可以考虑KMatrix数据的压缩存储方案。(即矩阵中数据为0的元素不占用存储空间),要求仍然以标准容器组合来实现。
- 设计迭代器可以访问矩阵的内容。
实现思路:
稀疏矩阵的存储:行逻辑链接的顺序表(压缩存储稀疏矩阵)
储存矩阵一共有多少行和列,然后利用三元组存储 -> 分别存储非0元素的 (行、列、值)vector <tuple<int, int, T>> m_matrix
按行、列依次增加的顺序储存。实现方法:定义一个vector m_rowPos 用来储存矩阵中,每一行第一个元素在m_matrix的位置。然后遍历就会节省时间,不需要遍历完整个容器里面的内容。添加元素时遍历找到相应的位置,再进行插入,而不是直接push_back。
删除行、列:找到相应的行/列,删除。但是同时还需要修改m_rowPos和行/列的值
矩阵的加、减运算:双迭代(指针)遍历
因为稀疏矩阵的存储的是有序的,所有仅仅要判断相加的矩阵a、b位置的位置知否相同,若位置相同进行值相加(减)然后保存,两个矩阵的指针都需要后移;若不相同的位置就先直接存储a、b中位置相对小的位置和值,后指针后移一位。
乘: ab
a的行b的列,但是列的遍历比较困难,因为矩阵的按行、列顺序存储的,所以把b转置后遍历会更快速。
代码实现
#include <iostream>
#include <vector>
#include <tuple>
#include <assert.h>
using namespace std;
template <typename T>
class KMatrix
{
public:
KMatrix(int row = 0, int column = 0); //构造函数
KMatrix(const KMatrix<T>& matrix); //构造函数
void init(int row, int column); //初始化行列信息
int getRows() const; //获得行
int getCols() const; //获得列.
void setData(int row, int col, T value);//设置值
T getData(int row, int col); //得到值
void eraseRow(int row); //删除行
void eraseCol(int col); //删除列
auto begin() const; //用于迭代访问矩阵的内容,开始
auto end() const; //用于迭代访问矩阵的内容,结束
KMatrix<T> operator+(const KMatrix<T>& b);//加
KMatrix<T> operator-(const KMatrix<T>& b);//减
KMatrix<T> operator*(const KMatrix<T>& b);//乘
KMatrix<T> transpose() const; //转置
void print() const; //输出
private:
int m_rows, m_columns; //是一个 m_rows × m_columns 的矩阵
vector <tuple<int, int, T>> m_matrix; //三元组,分别存储矩阵元素的行、列、值,压缩储存稀疏矩阵
vector <int> m_rowPos; //记录矩阵中每行第一个非 0 元素在m_matrix容器中的存储位置,更好的实现行的遍历和有序储存
vector<tuple<int, int, T>>& getMatrix(); //获取矩阵
};
//构造函数
template <typename T>
KMatrix<T>::KMatrix(int row, int column)
{
init(row, column);
}
//拷贝构造函数,深拷贝
template <typename T>
KMatrix<T>::KMatrix(const KMatrix<T>& matrix)
{
init(matrix.getRows(), matrix.getCols());
auto it = matrix.begin();
for (;it != matrix.end();it++) {
setData(get<0>(*it), get<1>(*it), get<2>(*it));
}
}
//初始化行列信息
template <typename T>
void KMatrix<T>::init(int row, int column)
{
//初始化行、列
m_rows = row;
m_columns = column;
//初始化行逻辑表
for (int i = 0;i <= row + 1;i++)
{
m_rowPos.push_back(0);
}
}
//获得行
template <typename T>
int KMatrix<T>::getRows() const
{
return m_rows;
}
//获得列
template <typename T>
int KMatrix<T>::getCols() const
{
return m_columns;
}
template <typename T>
void KMatrix<T>::setData(int row, int col, T value)//设置值
{
//断言判断合法性
assert(row >= 1 && row <= m_rows && col >= 1 && col <= m_columns, "添加的元素的位置不合法");
//判断是否已经存在
int temp = m_rowPos[row];
bool flag = false;
for (;temp < m_rowPos[row + 1];temp++) {
if (get<0>(m_matrix[temp]) == row && get<1>(m_matrix[temp]) == col) {
flag = true;
break;
}
//稀疏矩阵按行、列依次增大的顺序储存,找不到
if (get<1>(m_matrix[temp]) > col) {
break;
}
}
//若存在,修改值
if (flag) {
get<2>(m_matrix[temp]) = value;
}
//若不存在,添加
else {
//新建一个三元组
tuple<int, int, T> t(row, col, value);
//遍历迭代到对应的位置
auto it = m_matrix.begin();
for (int i = 0;i < temp;i++) {
it++;
}
m_matrix.insert(it, t);
//后面的行
for (int i = row + 1;i <= m_rows + 1;i++) {
m_rowPos[i]++;
}
}
}
//得到值
template <typename T>
T KMatrix<T>::getData(int row, int col)
{
//断言判断合法性
assert(row >= 1 && row <= m_rows && col>=1&& col<=m_columns, "查找的元素的位置不合法");
//判断是否能找到
int temp = m_rowPos[row];
bool flag = false;
for (;temp < m_rowPos[row + 1];temp++) {
if (get<0>(m_matrix[temp]) == row && get<1>(m_matrix[temp]) == col) {
flag = true;
break;
}
if (get<1>(m_matrix[temp]) > col) {
break;
}
}
if (flag) {
return get<2>(m_matrix[temp]);
}
else {
return 0;
}
}
//删除行
template <typename T>
void KMatrix<T>::eraseRow(int row)
{
//断言判断合法性
assert(row >= 1 && row <= m_rows, "删除的行不合法");
int temp = m_rowPos[row];
auto it = m_matrix.begin();
//找到记录这行的开始
for (int i = 0;i < temp;i++) {
it++;
}
//遍历删除这一整行的元素
for (;temp < m_rowPos[row + 1];temp++) {
auto t = it++;
m_matrix.erase(t);
}
//后面的行 pos 需要改变
for (int i = m_rows;i > row;i--) {
m_rowPos[i] = m_rowPos[i-1];
}
//找到记录这行的pos
auto rowIt = m_rowPos.begin();
for (int i = 1;i < row;i++) {
++rowIt;
}
//删除记录这行的pos
m_rowPos.erase(rowIt);
//行数-1
--m_rows;
}
//删除列
template <typename T>
void KMatrix<T>::eraseCol(int col)
{
//断言判断合法性
assert(col >= 1 && col <= m_columns, "删除的列不合法");
auto it = m_matrix.begin();
for (int i = 0;i < m_matrix.size();i++) {
if (get<1>(*it) == col) {
auto t = it++;
//后面的行,pos往前移动一个元素
for (int i = get<0>(*t);i < m_rows;i++) {
--m_rowPos[i + 1];
}
m_matrix.erase(t);
}
else {
it++;
}
}
--m_columns;
}
//用于迭代访问矩阵的内容
//开始
template <typename T>
auto KMatrix<T>::begin() const
{
return m_matrix.begin();
}
//用于迭代访问矩阵的内容
//结束
template <typename T>
auto KMatrix<T>::end() const
{
return m_matrix.end();
}
//加
template <typename T>
KMatrix<T> KMatrix<T>::operator+(const KMatrix<T>& b)
{
if (m_rows != b.getRows() || m_columns != b.getCols()) {
cout << "不符合相加条件" << endl;
return KMatrix<T>(0, 0);
}
//迭代遍历矩阵
auto it1 = m_matrix.begin();
auto it2 = b.begin();
//相加后的新矩阵行、列,与原矩阵一样
//储存相加的结果
KMatrix<T> ans(m_rows, m_columns);
while (it1 != m_matrix.end() || it2 != b.end()) {
if (it1 == m_matrix.end()) { //如果a矩阵已经一个没有非0元素
ans.setData(get<0>(*it2), get<1>(*it2), get<2>(*it2));
++it2;
}
else if (it2 == b.end()) { //如果b矩阵已经一个没有非0元素
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1));
++it1;
}
else if ((get<0>(*it1) == get<0>(*it2)) && (get<1>(*it1) == get<1>(*it2))) { //如果两个位置的行和列相等,直接相加
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1) + get<2>(*it2));
++it1;
++it2;
}
else if ((get<0>(*it1) > get<0>(*it2)) || ((get<1>(*it1) > get<1>(*it2)) && (get<0>(*it1) >= get<0>(*it2)))) { //如果a的位置it1 比b的位置it2 大,直接存储it2元素信息,并且it2往后移
ans.setData(get<0>(*it2), get<1>(*it2), get<2>(*it2));
++it2;
}
else if ((get<0>(*it1) < get<0>(*it2)) || ((get<1>(*it1) < get<1>(*it2)) && (get<0>(*it1) <= get<0>(*it2)))) { //如果b的位置it2 比a的位置it1 大,直接存储it1元素信息,并且it1往后移
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1));
++it1;
}
}
return ans;
}
//减
template <typename T>
KMatrix<T> KMatrix<T>::operator-(const KMatrix<T>& b)
{
//仅需要将b理解为-b,算法就会和加法一样
if (m_rows != b.getRows() || m_columns != b.getCols()) {
cout << "不符合相减条件" << endl;
return KMatrix<T>(0, 0);
}
auto it1 = m_matrix.begin();
auto it2 = b.begin();
//相减后的新矩阵行、列,与原矩阵一样
//储存相加的结果
KMatrix<T> ans(m_rows, m_columns);
while (it1 != m_matrix.end() || it2 != b.end()) {
if (it1 == m_matrix.end()) {
ans.setData(get<0>(*it2), get<1>(*it2), -get<2>(*it2));
++it2;
}
else if (it2 == b.end()) {
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1));
++it1;
}
else if ((get<0>(*it1) == get<0>(*it2)) && (get<1>(*it1) == get<1>(*it2))) {
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1) - get<2>(*it2));
++it1;
++it2;
}
else if ((get<0>(*it1) > get<0>(*it2)) || ((get<1>(*it1) > get<1>(*it2)) && (get<0>(*it1) >= get<0>(*it2)))) {
ans.setData(get<0>(*it2), get<1>(*it2), -get<2>(*it2));
++it2;
}
else if ((get<0>(*it1) < get<0>(*it2)) || ((get<1>(*it1) < get<1>(*it2)) && (get<0>(*it1) <= get<0>(*it2)))) {
ans.setData(get<0>(*it1), get<1>(*it1), get<2>(*it1));
++it1;
}
}
return ans;
}
//叉乘
template <typename T>
KMatrix<T> KMatrix<T>::operator*(const KMatrix<T>& b)
{
if (m_columns != b.getRows()) {
cout << "不符合相乘条件" << endl;
return KMatrix<T>(0, 0);
}
auto it1 = m_matrix.begin();
//储存相乘的结果
KMatrix<T> ans(m_rows, b.getCols());
//转置被乘的矩阵b,方便乘法的遍历(本来是按列遍历,转置后变成按行遍历,按行遍历是有序的)
KMatrix<T> tempB = b.transpose();
//遍历结果矩阵的行
for (int i = 1;i <= m_rows;i++) {
auto it2 = tempB.begin();
//遍历结果矩阵的列
for (int j = 1;j <= b.getCols();j++) {
T sum = 0;
it1 = m_matrix.begin();
for (int k = 0;k < m_rowPos[i];k++) {
it1++; //迭代器遍历到当前行,的第一个元素
}
//遍历 找到符合条件的元素进行 相乘后相加
for (int k = m_rowPos[i];k < m_rowPos[i+1] && it2 != tempB.end();) { //it1(k),it2都需要进行边界判断
if (get<0>(*it2) > j || get<0>(*it1)!=i) //it2的行数已经超过,不管it1有没有遍历完,都结束循环,因为相乘需要两个元素不为0,结果才不为0
break;
if (get<1>(*it1) == get<1>(*it2) && get<0>(*it2)==j) { //位置相等,相乘后相加
sum += get<2>(*it1) * get<2>(*it2);
it1++;
it2++;
k++;
}
else if (get<1>(*it1) < get<1>(*it2) && get<0>(*it2) == j) { //it1的位置小于it2的位置
it1++;
k++;
}
else { //it2的位置小于it1的位置
it2++;
}
}
if (sum != 0) { //矩阵压缩存储,结果为0不需要储存
ans.setData(i, j, sum);
}
}
}
return ans;
}
//转置
template <typename T>
KMatrix<T> KMatrix<T>::transpose() const
{
KMatrix<T> ans(m_columns, m_rows); //行数和列数相交换
auto it = m_matrix.begin();
for (;it != m_matrix.end();it++) {
ans.setData(get<1>(*it), get<0>(*it), get<2>(*it)); //行和列的相交换,就得转置的结果
}
return ans;
}
//输出
template <typename T>
void KMatrix<T>::print() const
{
auto it = m_matrix.begin();
for (int i = 1;i <= m_rows;i++) {
for (int j = 1;j <= m_columns;j++) {
if ((it != m_matrix.end()) && (get<0>(*it) == i && get<1>(*it) == j)) {
cout << get<2>(*it) << " ";
++it;
}
else {
cout << "0 "; //矩阵中不存在,则该位置的值为0
}
}
cout << endl;
}
}
//获取矩阵
template <typename T>
vector<tuple<int, int, T>>& KMatrix<T>::getMatrix()
{
return m_matrix;
}
/**************************************************************************
Author: YaoYuming
Date:2022-04-20
Description: KMatrix矩阵容器类,稀疏矩阵的压缩存储,实现基本运算
**************************************************************************/
#include <iostream>
#include <vector>
#include "KMatrix.cpp"
using namespace std;
int main()
{
KMatrix<int> a(3, 3);
a.setData(1, 1, 1);
a.setData(1, 2, 4);
a.setData(3, 1, 5);
a.setData(2, 2, 2);
a.setData(1, 1, 4);
cout << "a矩阵:" << endl;
a.print();
cout << endl;
KMatrix<int> b(3, 3);
b.setData(1, 1, 1);
b.setData(2, 2, 2);
b.setData(3, 3, 3);
cout << "b矩阵:" << endl;
b.print();
cout << endl;
cout << "a+b:" << endl;
KMatrix<int> add = a + b;
add.print();
cout << endl;
cout << "a-b:" << endl;
KMatrix<int> sub = a - b;
sub.print();
cout << endl;
cout << "a转置矩阵:" << endl;
KMatrix<int> trans = a.transpose();
trans.print();
cout << endl;
cout << "a*b:" << endl;
KMatrix<int> mul1 = a*b;
mul1.print();
cout << endl;
cout << "b*a:" << endl;
KMatrix<int> mul2 = b * a;
mul2.print();
cout << endl;
KMatrix<double> c(3, 3);
c.setData(1, 1, 1.1);
c.setData(2, 2, 2.2);
c.setData(3, 3, 3.3);
cout << "c矩阵:" << endl;
c.print();
c.eraseRow(3);
cout << "c矩阵:" << endl;
c.print();
cout << "获取c的行、列" << endl;
cout << c.getRows() << " " << c.getCols() << endl;
cout << "获取c的特定位置的值" << endl;
cout << c.getData(1, 1) << endl;
cout << c.getData(1, 2) << endl;
KMatrix<double> d(3, 3);
d.setData(1, 1, 5.3);
d.setData(2, 1, 6.2);
d.setData(3, 3, 7.3);
cout << endl << "d矩阵:" << endl;
d.print();
d.eraseCol(3);
cout << "d矩阵:" << endl;
d.print();
cout << "遍历输出d矩阵所储存的数据:" << endl;
for (auto it = d.begin();it != d.end();it++) {
cout << get<0>(*it) << " "
<< get<1>(*it) << " "
<< get<2>(*it) << endl;
}
}