// Sparse integer matrix stored as an orthogonal ("cross") linked list,
// plus a small demo that multiplies two 5x5 matrices.
//
// This source was originally three files pasted together:
//   cross.h   - Node / Cross declarations
//   cross.cpp - Cross member definitions (started with #include "cross.h")
//   mm.cpp    - demo driver               (started with #include "cross.h")
// It is kept here as one self-contained translation unit. If it is split
// again, give cross.h an include guard, move the #includes below into it,
// and #include "cross.h" from both .cpp files.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>

// ===================================================================== cross.h

// One non-zero entry of the matrix. Each node is linked into exactly one row
// list (via `right`) and exactly one column list (via `down`).
struct Node
{
    int row;       // 0-based row index (int so it matches the Cross API; was short)
    int col;       // 0-based column index
    int num;       // stored value
    Node * right;  // next element in the same row
    Node * down;   // next element in the same column

    Node() : row(-1), col(-1), num(0), right(NULL), down(NULL) {}
    Node(int _row, int _col, int _val)
        : row(_row), col(_col), num(_val), right(NULL), down(NULL) {}
};

// Sparse m_row x m_col integer matrix.
// Invariant: every row list is sorted by ascending col, every column list is
// sorted by ascending row, and no (row, col) position appears twice.
// Multiplication relies on this invariant.
class Cross
{
private:
    int m_row;     // number of rows
    int m_col;     // number of columns
    Node ** pRow;  // pRow[i]: head of row i (NULL if the row is empty)
    Node ** pCol;  // pCol[j]: head of column j (NULL if the column is empty)

    Cross();                          // not implemented: dimensions are required
    Cross(const Cross &);             // not implemented: a shallow copy would double-free nodes
    Cross & operator=(const Cross &); // not implemented: same reason as the copy constructor

public:
    // Creates an empty row x col matrix. Both dimensions must be non-negative.
    Cross(int row, int col);
    ~Cross();

    int GetRow() const { return m_row; }
    int GetCol() const { return m_col; }

    // Loads every non-zero element of a dense row-major GetRow() x GetCol()
    // array into this matrix.
    void GetChainFromMatrix(const int * arr);
    // Adds `num` at (row, col), keeping both lists sorted. If the position
    // already holds a node, `num` is added to its value.
    void Add(int row, int col, int num);
    // result += (*this) * multiplicator, storing only non-zero products.
    void Multiplication(const Cross & multiplicator, Cross & result) const;
    // Writes the matrix into a dense row-major GetRow() x GetCol() array.
    void GetMatrix(int * arr) const;
    // Prints each row's nodes as "(row,col,value)" separated by tabs, one row per line.
    void Show() const;
};

// =================================================================== cross.cpp

Cross::Cross(int row, int col)
    : m_row(row), m_col(col), pRow(NULL), pCol(NULL)
{
    assert(row >= 0 && col >= 0);

    // `new T[n]()` value-initializes, so every head pointer starts as NULL.
    pRow = new Node*[m_row]();
    try
    {
        pCol = new Node*[m_col]();
    }
    catch (...)
    {
        delete [] pRow;  // do not leak the row table if the column table fails
        throw;
    }
}

Cross::~Cross()
{
    // Every node is in exactly one row list, so walking the row lists frees
    // each node once. The column heads only alias those same nodes.
    for (int i = 0; i < m_row; ++i)
    {
        Node * node = pRow[i];
        while (node)
        {
            Node * next = node->right;
            delete node;
            node = next;
        }
    }
    delete [] pRow;
    delete [] pCol;
}

// `arr` must point to GetRow() * GetCol() ints in row-major order
// (element (i, j) at arr[i * GetCol() + j]). Zero elements are skipped.
// If this matrix already holds entries, the loaded values are added to them.
void Cross::GetChainFromMatrix(const int * arr)
{
    for (int i = 0; i < m_row; ++i)
    {
        for (int j = 0; j < m_col; ++j)
        {
            // Row-major stride is the column count (was m_row, which breaks
            // non-square matrices).
            const int value = arr[i * m_col + j];
            if (value != 0)
                Add(i, j, value);
        }
    }
}

void Cross::Add(int row, int col, int num)
{
    assert(row >= 0 && row < m_row);
    assert(col >= 0 && col < m_col);

    // Locate the insertion point in row `row` (sorted by ascending col).
    // Using a pointer to the link also covers insertion before the head.
    Node ** link = &pRow[row];
    while (*link && (*link)->col < col)
        link = &(*link)->right;

    if (*link && (*link)->col == col)
    {
        // The position is already stored. Accumulate rather than create a
        // duplicate node, which would break the sorted merge in Multiplication.
        (*link)->num += num;
        return;
    }

    Node * add = new Node(row, col, num);
    add->right = *link;
    *link = add;

    // Link into column `col` (sorted by ascending row). No node at
    // (row, col) exists here, because the row list had none.
    link = &pCol[col];
    while (*link && (*link)->row < row)
        link = &(*link)->down;
    add->down = *link;
    *link = add;
}

// Preconditions (asserted):
//   GetCol() == multiplicator.GetRow(),
//   crossResult is GetRow() x multiplicator.GetCol(),
//   crossResult is neither *this nor multiplicator (adding into an operand
//   while walking its lists would corrupt the traversal).
// Products are accumulated into crossResult, so an empty crossResult yields
// exactly the product. Note: int products/sums may overflow for large values.
void Cross::Multiplication(const Cross & multiplicator, Cross & crossResult) const
{
    assert(m_col == multiplicator.m_row);
    assert(crossResult.m_row == m_row && crossResult.m_col == multiplicator.m_col);
    assert(&crossResult != this && &crossResult != &multiplicator);

    for (int i = 0; i < m_row; ++i)
    {
        for (int j = 0; j < multiplicator.m_col; ++j)
        {
            int sum = 0;  // value of crossResult at (i, j)
            const Node * pR = pRow[i];               // row i of *this, ascending col
            const Node * pC = multiplicator.pCol[j]; // column j of multiplicator, ascending row

            // Sorted merge: only indices present in both lists contribute.
            while (pR && pC)
            {
                if (pR->col == pC->row)
                {
                    sum += pR->num * pC->num;
                    pR = pR->right;
                    pC = pC->down;
                }
                else if (pR->col < pC->row)
                {
                    pR = pR->right;
                }
                else
                {
                    pC = pC->down;
                }
            }

            // Keep every non-zero result; the old `> 0` test dropped negatives.
            if (sum != 0)
                crossResult.Add(i, j, sum);
        }
    }
}

// `arr` must have room for GetRow() * GetCol() ints. It is zero-filled, then
// every stored element is written at arr[row * GetCol() + col].
void Cross::GetMatrix(int * arr) const
{
    std::fill(arr, arr + m_row * m_col, 0);
    for (int i = 0; i < m_row; ++i)
    {
        for (const Node * node = pRow[i]; node; node = node->right)
            arr[i * m_col + node->col] = node->num;  // row-major stride is m_col
    }
}

void Cross::Show() const
{
    for (int i = 0; i < m_row; ++i)
    {
        for (const Node * node = pRow[i]; node; node = node->right)
        {
            std::cout << "(" << node->row << "," << node->col << "," << node->num << ")" << '\t';
        }
        std::cout << '\n';
    }
}

// ====================================================================== mm.cpp

int main()
{
    const int num1[5][5] = {
        {1,0,0,0,0},
        {0,2,0,0,0},
        {0,3,0,0,0},
        {0,0,0,4,0},
        {0,0,0,0,5}};
    const int num2[5][5] = {
        {1,0,0,0,0},
        {2,0,0,0,0},
        {0,3,0,0,0},
        {0,0,0,0,4},
        {0,0,5,0,0}};
    int num3[5][5];

    Cross cr1(5, 5);
    cr1.GetChainFromMatrix(&num1[0][0]);
    cr1.Show();
    std::cout << '\n';

    Cross cr2(5, 5);
    cr2.GetChainFromMatrix(&num2[0][0]);
    cr2.Show();
    std::cout << '\n';

    Cross cr3(5, 5);
    cr1.Multiplication(cr2, cr3);
    cr3.Show();
    std::cout << '\n';

    cr3.GetMatrix(&num3[0][0]);
    for (int i = 0; i < 5; ++i)
    {
        for (int j = 0; j < 5; ++j)
        {
            std::cout << num3[i][j] << '\t';
        }
        std::cout << '\n';
    }
    return 0;
}