question
完整代码
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <vector>
#define MAX_SIZE 1000
/// A single stored element of a sparse matrix: a (row, column, value) triple.
/// Indices are used zero-based by Multiply and SparsetoNormal.
struct Trida
{
    int row; ///< row index of the element
    int col; ///< column index of the element
    int val; ///< value stored at (row, col)
};
/// Triple-list representation of an m x n sparse matrix.
/// Only data[0 .. len-1] hold meaningful entries; capacity is fixed at MAX_SIZE.
struct SparseMatrix
{
    Trida data[MAX_SIZE]; ///< storage for the stored triples
    int m;                ///< number of rows
    int n;                ///< number of columns
    int len;              ///< number of triples currently stored in data
};
// Forward declarations of the sparse-matrix helpers defined after main().
// Read triples from std::cin into Mat until the "0 0 0" sentinel.
void buildSparseMatrix(SparseMatrix &Mat);
// Overwrite slot `pos` of Mat.data with (row, col, val).
void insertSparseMatrix(SparseMatrix &Mat, int row, int col, int val, int pos);
// Value stored at (row, col), or 0 when no triple matches.
int GetValue(SparseMatrix Mat, int row, int col);
// Sparse product A * B.
SparseMatrix Multiply(SparseMatrix A, SparseMatrix B);
// Print every stored triple on its own line.
void showSparseMatrix(SparseMatrix Mat);
// Print the matrix as a dense grid.
void SparsetoNormal(SparseMatrix Mat);
int main()
{
SparseMatrix A, B;
std::cin >> A.m >> A.n;
buildSparseMatrix(A);
std::cin >> B.m >> B.n;
buildSparseMatrix(B);
if (A.n != B.m)
{
std::cout << "输入的矩阵不符合乘法运算!" << std::endl;
return 0;
}
else
{
SparseMatrix result = Multiply(A, B);
showSparseMatrix(result);
system("pause");
return 0;
}
}
// Reads (row, col, val) triples from std::cin into Mat.data until the "0 0 0"
// sentinel, then sets Mat.len to the number of stored triples.
//  - Reading also stops when the stream fails or reaches EOF. Previously a
//    missing sentinel made the loop spin forever and overrun Mat.data.
//  - At most MAX_SIZE triples are stored. Further triples are still consumed up
//    to the sentinel (so the input that follows stays aligned) but are
//    discarded, and a warning is written to std::cerr.
// Mat.m and Mat.n are neither read nor modified here.
void buildSparseMatrix(SparseMatrix &Mat)
{
    int row = 0, col = 0, val = 0;
    int cnt = 0;
    bool dropped = false;
    while (std::cin >> row >> col >> val)
    {
        if (row == 0 && col == 0 && val == 0)
            break;
        if (cnt < MAX_SIZE)
        {
            Mat.data[cnt] = Trida{row, col, val};
            ++cnt;
        }
        else
        {
            dropped = true; // capacity exhausted: keep reading, but do not store
        }
    }
    Mat.len = cnt;
    if (dropped)
        std::cerr << "三元组个数超过 MAX_SIZE,多余的三元组已忽略!" << std::endl;
}
// Overwrites slot `pos` of Mat.data with the triple (row, col, val).
// Mat.len is left unchanged and `pos` is not checked against MAX_SIZE; the
// caller is responsible for both.
void insertSparseMatrix(SparseMatrix &Mat, int row, int col, int val, int pos)
{
    Trida &slot = Mat.data[pos];
    slot.val = val;
    slot.col = col;
    slot.row = row;
}
// Returns the value of the first stored triple located at (row, col), or 0
// when no triple in Mat.data[0 .. Mat.len-1] matches.
// Note: Mat is taken by value, so every call copies the whole struct.
int GetValue(SparseMatrix Mat, int row, int col)
{
    for (int idx = 0; idx < Mat.len; ++idx)
    {
        const Trida &entry = Mat.data[idx];
        if (entry.row == row && entry.col == col)
            return entry.val; // first match wins
    }
    return 0;
}
SparseMatrix Multiply(SparseMatrix A, SparseMatrix B)
{
SparseMatrix C;
C.m = A.m, C.n = B.n;
for (auto i = 0; i < C.m * C.n; ++i)
{
C.data[i].row = 0;
C.data[i].col = 0;
C.data[i].val = 0;
}
int pos = 0, temp = 0;
for (auto i = 0; i < A.m; ++i)
{
for (auto j = 0; j < B.n; ++j)
{
for (auto k = 0; k < B.m; ++k)
temp += GetValue(A, i, k) * GetValue(B, k, j);
if (temp == 0)
continue;
insertSparseMatrix(C, i, j, temp, pos);
temp = 0, pos++;
}
}
C.len = pos;
return C;
}
// Prints each stored triple of Mat as "row col val", one triple per line,
// in storage order.
void showSparseMatrix(SparseMatrix Mat)
{
    const Trida *entry = Mat.data;
    for (int remaining = Mat.len; remaining > 0; --remaining, ++entry)
        std::cout << entry->row << " " << entry->col << " " << entry->val << std::endl;
}
// Prints Mat as a dense Mat.m x Mat.n grid. Each cell is written with width 3
// followed by a space, and each row ends with a newline.
//  - Triples whose (row, col) lies outside the grid are ignored. Previously
//    they were written out of bounds (undefined behaviour).
//  - Negative dimensions are treated as 0 instead of making the vector
//    constructor throw.
//  - For a repeated (row, col) the earliest triple is shown, matching
//    GetValue's first-match rule.
void SparsetoNormal(SparseMatrix Mat)
{
    const int rows = Mat.m > 0 ? Mat.m : 0;
    const int cols = Mat.n > 0 ? Mat.n : 0;
    std::vector<std::vector<int>> grid(rows, std::vector<int>(cols, 0));

    // Walk from the last triple to the first so earlier triples overwrite later ones.
    for (int t = Mat.len - 1; t >= 0; --t)
    {
        const Trida &e = Mat.data[t];
        if (e.row < 0 || e.row >= rows || e.col < 0 || e.col >= cols)
            continue;
        grid[e.row][e.col] = e.val;
    }

    for (const auto &line : grid)
    {
        for (int value : line)
            std::cout << std::setw(3) << value << ' ';
        std::cout << std::endl;
    }
}