十字链表存储的稀疏矩阵的加法运算
题目描述
给定两个用十字链表存储的稀疏矩阵,你需要编写一个程序,实现这两个矩阵的加法运算。
输入格式
输入包含两个部分,每个部分描述一个稀疏矩阵。
每个部分的第一行包含两个整数 n n n 和 m m m,表示矩阵的行数和列数。
接下来的 n n n 行,每行 m m m 个整数,表示该矩阵的元素。其中,0表示该位置没有元素。
输出格式
输出一个矩阵,表示两个输入矩阵的和。矩阵的格式与输入格式相同。
样例 #1
样例输入 #1
3 3
1 0 0
0 0 5
3 4 0
1 2 0
0 0 5
3 0 0
样例输出 #1
2 2 0
0 0 10
6 4 0
思路
对于矩阵 B
中的每一个非零元素,都需要在矩阵 A
中找到相应的位置,然后进行相应的操作。这个过程可以分为三种情况:
-
如果矩阵
A
中对应位置没有非零元素,即A
中的三元组在相应位置是零,那么就直接将B
中的元素添加到A
的相应位置。这个操作通过创建一个新的OLNode
节点,然后将其插入到A
的行链表和列链表的正确位置来完成。 -
如果矩阵
A
中对应位置有非零元素,并且A
和B
中的元素相加不为零,那么就更新A
中对应位置的元素值。这个操作通过将B
中的元素值加到A
中的元素值上来完成。 -
如果矩阵
A
中对应位置有非零元素,但是A
和B
中的元素相加为零,那么就需要删除A
中的这个元素。这个操作通过修改A
的行链表和列链表,将对应的OLNode
节点从链表中移除,然后释放该节点的内存来完成。
算法分析
时间复杂度
以上三种情况都需要遍历 A
的行链表和列链表,找到相应的位置。这个过程的时间复杂度为
O
(
k
)
O(k)
O(k),其中
k
k
k 是矩阵中非零元素的数量。
空间复杂度
因为每个非零元素都需要进行查找和可能的添加或删除操作,所以空间复杂度也为 O ( k ) O(k) O(k),需要存储每个非零元素的行索引、列索引、元素值以及两个指针。
代码
#include <algorithm>
#include <iostream>
#define AUTHOR "HEX9CF"
using namespace std;
using Status = int;
using ElemType = int;
const int N = 1e4 + 7;
const int TRUE = 1;
const int FALSE = 0;
const int OK = 1;
const int ERROR = 0;
const int INFEASIBLE = -1;
// const int OVERFLOW = -2;
int n, m;
ElemType a[N][N], b[N][N];
struct OLNode {
int i, j;
ElemType e;
OLNode *right, *down;
};
using OLink = OLNode *;
struct CrossList {
OLink *rhead, *chead;
int m, n, num;
};
Status createSMatrixOL(CrossList &M, int m, int n) {
M.m = m;
M.n = n;
M.num = m * n;
M.rhead = (OLink *)malloc((m + 1) * sizeof(OLink));
if (!M.rhead) {
return ERROR;
}
M.chead = (OLink *)malloc((m + 1) * sizeof(OLink));
if (!M.chead) {
return ERROR;
}
for (int i = 0; i <= m; i++) {
M.rhead[i] = nullptr;
}
for (int i = 0; i <= n; i++) {
M.chead[i] = nullptr;
}
}
Status add(CrossList &M, int i, int j, ElemType e) {
OLNode *p = (OLNode *)malloc(sizeof(OLNode));
if (!p) {
return ERROR;
}
p->i = i;
p->j = j;
p->e = e;
if (!M.rhead[i] || !M.rhead[i]->j > j) {
p->right = M.rhead[i];
M.rhead[i] = p;
} else {
OLNode *q = M.rhead[i];
while (q->right && q->right->j < j) {
q = q->right;
}
p->right = q->right;
q->right = p;
}
if (!M.chead[j] || M.chead[j]->i > i) {
p->down = M.chead[j];
M.chead[j] = p;
} else {
OLNode *q = M.chead[j];
while (q->down && q->down->i < i) {
q = q->down;
}
p->down = q->down;
q->down = p;
}
return OK;
}
Status display(CrossList &M) {
for (int i = 1; i <= M.m; i++) {
OLNode *rowNode = M.rhead[i];
for (int j = 1; j <= M.n; j++) {
if (rowNode && rowNode->j == j) {
cout << rowNode->e << " ";
rowNode = rowNode->right;
} else {
cout << 0 << " ";
}
}
cout << "\n";
}
cout << "\n";
}
Status addMatrix(CrossList &A, CrossList &B) {
if (A.m != B.m || A.n != B.n) {
return ERROR;
}
for (int i = 1; i <= A.m; i++) {
OLNode *bRowNode = B.rhead[i];
while (bRowNode) {
OLNode *aRowNode = A.rhead[i];
OLNode *aLeftNode = nullptr;
while (aRowNode && aRowNode->j < bRowNode->j) {
aLeftNode = aRowNode;
aRowNode = aRowNode->right;
}
if (aRowNode && aRowNode->j == bRowNode->j) {
// 提取到的 B 中三元组在 A 相应位置上有非 0 元素
aRowNode->e += bRowNode->e;
// 相加为 0
if (aRowNode->e == 0) {
// 删除矩阵 A 中对应结点
if (aLeftNode) {
aLeftNode->right = aRowNode->right;
} else {
A.rhead[i] = aRowNode->right;
}
OLNode *aColNode = A.chead[i];
OLNode *aUpNode = nullptr;
while (aColNode && aColNode->i < bRowNode->i) {
aUpNode = aColNode;
aColNode = aColNode->down;
}
if (aUpNode) {
aUpNode->down = aColNode->down;
} else {
A.chead[i] = aColNode->down;
}
free(aRowNode);
aRowNode = nullptr;
}
} else {
// 提取到的 B 中的三元组在 A 相应位置上没有非 0 元素
// 直接加到矩阵 A 该行链表的对应位置上
OLNode *newNode = (OLNode *)malloc(sizeof(OLNode));
if (!newNode) {
return ERROR;
}
if (aLeftNode) {
aLeftNode->right = newNode;
} else {
A.rhead[i] = newNode;
}
OLNode *aColNode = A.chead[i];
OLNode *aUpNode = nullptr;
while (aColNode && aColNode->i < bRowNode->i) {
aUpNode = aColNode;
aColNode = aColNode->down;
}
if (aUpNode) {
aUpNode->down = newNode;
} else {
A.chead[i] = newNode;
}
newNode->i = bRowNode->i;
newNode->j = bRowNode->j;
newNode->e = bRowNode->e;
newNode->right = aRowNode;
newNode->down = aColNode;
}
bRowNode = bRowNode->right;
}
}
return OK;
}
int main() {
cin >> m >> n;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
cin >> a[i][j];
}
}
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
cin >> b[i][j];
}
}
CrossList A;
createSMatrixOL(A, m, n);
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
add(A, i + 1, j + 1, a[i][j]);
}
}
CrossList B;
createSMatrixOL(B, m, n);
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
add(B, i + 1, j + 1, b[i][j]);
}
}
// display(A);
// display(B);
addMatrix(A, B);
display(A);
return 0;
}