基于正交链表表示的稀疏矩阵运算实现

稀疏矩阵的正交链表表示由 head 节点、行链表、列链表、行指针数组、列指针数组以及正交链组成。head 节点同时是行链表和列链表的头结点;行指针数组和列指针数组的元素分别指向行链表和列链表中的对应节点;行链表和列链表均与 head 构成双向循环链表。正交链中,稀疏矩阵的每一个非零节点由行号、列号、关键字(值)、指向行链表上前驱和后继的指针,以及指向列链表上前驱和后继的指针组成。
下面的代码基于正交链表示实现了以下操作
稀疏矩阵加法(不改变原矩阵)
稀疏矩阵加法(用相加结果覆盖原矩阵)
稀疏矩阵乘法
稀疏矩阵转置(结果存放在新矩阵)
稀疏矩阵转置(原矩阵被转置结果覆盖)
C++代码:

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

template <typename T>
struct OrthogonalChainsNode
{
    // Role of a node in the orthogonal list: the matrix sentinel, a row header,
    // a column header, or a stored matrix entry.
    enum class type {head, row_head, col_head, element};

    // Members carry default initializers so that a default-constructed node
    // (used for the sentinel and header nodes) never holds indeterminate values.
    type flag = type::element;
    size_t row = 0;  // 1-based row index; 0 on column headers, the row count on the sentinel
    size_t col = 0;  // 1-based column index; 0 on row headers, the column count on the sentinel
    T value{};
    OrthogonalChainsNode* row_pre = nullptr;   // predecessor in the row ring
    OrthogonalChainsNode* row_post = nullptr;  // successor in the row ring
    OrthogonalChainsNode* col_pre = nullptr;   // predecessor in the column ring
    OrthogonalChainsNode* col_post = nullptr;  // successor in the column ring

    // Initializers listed in declaration order (flag, row, col, value).
    OrthogonalChainsNode(size_t r, size_t c, const T& v, type f) : flag(f), row(r), col(c), value(v) {}
    OrthogonalChainsNode() = default;
    // Copies only the payload (flag, position, value); the copy starts unlinked,
    // i.e. all four link pointers are nullptr.
    OrthogonalChainsNode(const OrthogonalChainsNode& copy) : flag(copy.flag), row(copy.row), col(copy.col), value(copy.value) {}
};

template <typename T>
void insert_before_col(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted)
{
    // Splice `inserted` into goal's column ring, directly in front of `goal`.
    OrthogonalChainsNode<T>* const before = goal->col_pre;
    inserted->col_pre = before;
    before->col_post = inserted;
    inserted->col_post = goal;
    goal->col_pre = inserted;
}

template <typename T>
void insert_before_row(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted)
{
    // Splice `inserted` into goal's row ring, directly in front of `goal`.
    OrthogonalChainsNode<T>* const before = goal->row_pre;
    inserted->row_pre = before;
    before->row_post = inserted;
    inserted->row_post = goal;
    goal->row_pre = inserted;
}

template <typename T>
void insert_after_col(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted)
{
    // Splice `inserted` into goal's column ring, directly behind `goal`.
    OrthogonalChainsNode<T>* const after = goal->col_post;
    inserted->col_post = after;
    after->col_pre = inserted;
    inserted->col_pre = goal;
    goal->col_post = inserted;
}

template <typename T>
void insert_after_row(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted)
{
    // Splice `inserted` into goal's row ring, directly behind `goal`.
    OrthogonalChainsNode<T>* const after = goal->row_post;
    inserted->row_post = after;
    after->row_pre = inserted;
    inserted->row_pre = goal;
    goal->row_post = inserted;
}

template <typename T>
void deleteRow(OrthogonalChainsNode<T>* goal)
{
    // Unlink `goal` from its row ring by joining its neighbours.
    // goal's own row_pre/row_post are intentionally left untouched, so callers
    // can still step to the old successor after unlinking.
    OrthogonalChainsNode<T>* const before = goal->row_pre;
    OrthogonalChainsNode<T>* const after = goal->row_post;
    before->row_post = after;
    after->row_pre = before;
}

template <typename T>
void deleteCol(OrthogonalChainsNode<T>* goal)
{
    // Unlink `goal` from its column ring by joining its neighbours.
    // goal's own col_pre/col_post are intentionally left untouched, so callers
    // can still step to the old successor after unlinking.
    OrthogonalChainsNode<T>* const before = goal->col_pre;
    OrthogonalChainsNode<T>* const after = goal->col_post;
    before->col_post = after;
    after->col_pre = before;
}

template <typename T>
void deleteRowColElement(OrthogonalChainsNode<T>* goal)
{
    // Remove `goal` from both rings it belongs to. The two unlinks touch
    // disjoint link fields (col_* vs row_*), so their order does not matter.
    // The node is not freed and keeps its own (now stale) links.
    deleteCol(goal);
    deleteRow(goal);
}

/// One entry of the coordinate ("triple") form of a sparse matrix.
template <typename T>
struct TripleElemet
{
    size_t row;  // 1-based row index of the entry
    size_t col;  // 1-based column index of the entry
    T value;     // value stored at (row, col)
};

/// Coordinate ("triple") form of a sparse matrix: its full dimensions plus the
/// list of stored entries.
template <typename T>
struct Triple
{
    size_t row;  // total number of rows of the corresponding dense matrix
    size_t col;  // total number of columns of the corresponding dense matrix
    vector<TripleElemet<T>> Tlist;  // stored entries (1-based coordinates)
};


template <typename T>
class SparseMatrix
{
public:
    friend void toMatrix(const SparseMatrix<long long>& test, vector<vector<long long>>& r);
    void add(const SparseMatrix<T>& be_added);
    void addGenNewSparseMatrix(const SparseMatrix<T>& be_added, SparseMatrix<T>& result);
    void multiply(const SparseMatrix<T>& be_multiply, SparseMatrix<T>& result);
    void transpose(SparseMatrix<T>& result);
    void transpose();
    SparseMatrix(size_t row, size_t col, const T& zero);   //result参数的稀疏矩阵必须用该构造函数初始化
    SparseMatrix(const SparseMatrix<T>& copy);
    SparseMatrix(const vector<vector<T>>& M, const T& zero);
    SparseMatrix(const Triple<T>& T_M, const T& zero);
    ~SparseMatrix();
private:
    void insert_before_in_row_col_list(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted, vector<OrthogonalChainsNode<T>*>& col_list);
    void deleteInRowColList(OrthogonalChainsNode<T>* goal, vector<OrthogonalChainsNode<T>*>& col_list);
    void initSparseMatrix(size_t row, size_t col, const T& z);
    void insert_in_col_list(OrthogonalChainsNode<T>* inserted, vector<OrthogonalChainsNode<T>*>& col_list);
    OrthogonalChainsNode<T>* head;
    vector<OrthogonalChainsNode<T>*> row_ptr;
    vector<OrthogonalChainsNode<T>*> col_ptr;  //行列数存放在head节点
    T zero;
};

/// Builds a sparse matrix from a dense row-major matrix `M`, storing only the
/// entries that differ from `zero`.
///
/// The column count is taken from the first row (0 when `M` is empty). A row
/// shorter than that is treated as padded with zeros; surplus entries in a
/// longer row are ignored.
template <typename T>
SparseMatrix<T>::SparseMatrix(const vector<vector<T>>& M, const T& zero)
    // Delegate to the (row, col, zero) constructor. The previous explicit call
    // `this->SparseMatrix<T>::SparseMatrix(...)` is not standard C++.
    : SparseMatrix(M.size(), M.empty() ? 0 : M[0].size(), zero)
{
    const size_t cols = head->col;
    // Current tail of every column list; new nodes are appended after it.
    vector<OrthogonalChainsNode<T>*> col_list = col_ptr;
    for (size_t i = 0; i < M.size(); ++i)
    {
        OrthogonalChainsNode<T>* const row_head = row_ptr[i];
        const size_t width = std::min<size_t>(cols, M[i].size());
        for (size_t j = 0; j < width; ++j)
        {
            if (M[i][j] != zero)
            {
                OrthogonalChainsNode<T>* t = new OrthogonalChainsNode<T>(i + 1, j + 1, M[i][j], OrthogonalChainsNode<T>::type::element);
                // Inserting before the row header appends at the end of row i+1.
                insert_before_in_row_col_list(row_head, t, col_list);
            }
        }
    }
}

/// Builds a sparse matrix from its coordinate ("triple") form.
///
/// Preconditions (not checked): every entry has 1-based coordinates within
/// T_M.row x T_M.col, there are no duplicate coordinates, and T_M.Tlist is
/// sorted row-major (by row, then by column). Nodes are appended to the tails
/// of their row and column lists, so unsorted input would break list order.
/// Entries whose value equals `zero` are skipped, matching the other constructors.
template <typename T>
SparseMatrix<T>::SparseMatrix(const Triple<T>& T_M, const T& zero)
    // Delegate instead of the non-standard explicit constructor call.
    : SparseMatrix(T_M.row, T_M.col, zero)
{
    vector<OrthogonalChainsNode<T>*> col_list = col_ptr;
    for (const TripleElemet<T>& e : T_M.Tlist)
    {
        if (e.value == zero)
            continue;
        // Triple coordinates are 1-based while row_ptr is 0-based; indexing with
        // e.row directly was off by one and overran row_ptr on the last row.
        OrthogonalChainsNode<T>* const row_head = row_ptr[e.row - 1];
        OrthogonalChainsNode<T>* t = new OrthogonalChainsNode<T>(e.row, e.col, e.value, OrthogonalChainsNode<T>::type::element);
        insert_before_in_row_col_list(row_head, t, col_list);
    }
}

/// Transposes the matrix in place by relinking the existing element nodes
/// (no element is reallocated).
///
/// Row headers are chained through head's column ring and column headers
/// through head's row ring. Header nodes are added or removed when the
/// matrix is not square.
template <typename T>
void SparseMatrix<T>::transpose() 
{
    using Node = OrthogonalChainsNode<T>;

    // Dimensions before the transpose; afterwards head holds them swapped.
    const size_t old_rows = head->row;
    const size_t old_cols = head->col;
    swap(head->row, head->col);

    // While elements move, the old shape needs old_rows row headers and
    // old_cols column headers, and the new shape needs old_cols and old_rows.
    // Grow the shorter header array so both fit; the surplus is trimmed below.
    if (old_rows > old_cols)
    {
        col_ptr.resize(old_rows);
        for (size_t i = old_cols; i < old_rows; ++i)
        {
            Node* t = new Node();
            col_ptr[i] = t;
            t->col = i + 1;
            t->row = 0;
            t->col_post = t;
            t->col_pre = t;
            t->flag = Node::type::col_head;
            insert_before_row(head, t);  // column headers hang off head's row ring
        }
    }
    else if (old_rows < old_cols)
    {
        row_ptr.resize(old_cols);
        for (size_t i = old_rows; i < old_cols; ++i)
        {
            Node* t = new Node();
            row_ptr[i] = t;
            t->row = i + 1;
            t->col = 0;
            t->row_post = t;
            t->row_pre = t;
            t->flag = Node::type::row_head;
            insert_before_col(head, t);  // row headers hang off head's column ring
        }
    }
    const size_t new_rows = old_cols;

    // Phase 1: old column j becomes new row j. Each node is unlinked from both
    // rings and appended (in increasing old-row = new-column order) to the front
    // part of row j. Unlinking and insert_after_row leave run's own col links
    // intact, so run->col_post still walks the old column.
    for (size_t j = 0; j < old_cols; ++j)
    {
        Node* run = col_ptr[j]->col_post;
        Node* pre = row_ptr[j];
        while (run != col_ptr[j])
        {
            deleteRowColElement(run);
            insert_after_row(pre, run);
            pre = run;
            swap(run->row, run->col);
            run = run->col_post;
        }
    }

    // Phase 2: every column ring is now empty; rebuild them row by row so each
    // column stays sorted by row.
    vector<Node*> col_list(col_ptr.begin(), col_ptr.begin() + head->col);
    for (size_t i = 0; i < new_rows; ++i)
    {
        for (Node* run = row_ptr[i]->row_post; run != row_ptr[i]; run = run->row_post)
            insert_in_col_list(run, col_list);
    }

    // Phase 3: free surplus (now empty) header nodes. They must first be
    // unlinked from head's header rings; otherwise head keeps dangling
    // pointers and a later transpose writes through freed memory.
    if (head->row < head->col)
    {
        for (size_t i = head->row; i < head->col; ++i)
        {
            deleteCol(row_ptr[i]);
            delete row_ptr[i];
        }
        row_ptr.resize(head->row);
    }
    else if (head->row > head->col)
    {
        for (size_t i = head->col; i < head->row; ++i)
        {
            deleteRow(col_ptr[i]);
            delete col_ptr[i];
        }
        col_ptr.resize(head->col);
    }
}

/// Writes the transpose of *this into `result`, leaving *this unchanged.
/// `result` is expected to be a freshly constructed, empty col x row matrix.
template <typename T>
void SparseMatrix<T>::transpose(SparseMatrix<T>& result)
{
    // Tail of every column list of `result`; new nodes are appended after it.
    vector<OrthogonalChainsNode<T>*> tails = result.col_ptr;
    const size_t n_cols = head->col;
    for (size_t c = 0; c < n_cols; ++c)
    {
        // Column c of *this becomes row c of `result`. Walking the column top to
        // bottom yields new column indices in increasing order.
        OrthogonalChainsNode<T>* const src_head = col_ptr[c];
        OrthogonalChainsNode<T>* const dst_row = result.row_ptr[c];
        for (OrthogonalChainsNode<T>* node = src_head->col_post; node != src_head; node = node->col_post)
        {
            insert_before_in_row_col_list(dst_row, new OrthogonalChainsNode<T>(node->col, node->row, node->value, node->flag), tails);
        }
    }
}

/// Frees every element node, then all row/column headers, then the sentinel.
template <typename T>
SparseMatrix<T>::~SparseMatrix()
{
    using Node = OrthogonalChainsNode<T>;
    const size_t n_rows = head->row;

    for (size_t r = 0; r < n_rows; ++r)
    {
        Node* const sentinel = row_ptr[r];
        Node* cur = sentinel->row_post;
        while (cur != sentinel)
        {
            deleteRowColElement(cur);
            // Unlinking leaves cur's own links intact, so row_post is still its successor.
            Node* const next = cur->row_post;
            delete cur;
            cur = next;
        }
    }

    for (size_t r = 0; r < n_rows; ++r)
        delete row_ptr[r];

    const size_t n_cols = head->col;
    for (size_t c = 0; c < n_cols; ++c)
        delete col_ptr[c];

    delete head;
}

/// Deep copy: allocates a new sentinel, new headers and a new node for every
/// stored entry of `copy`.
template <typename T>
SparseMatrix<T>::SparseMatrix(const SparseMatrix<T>& copy)
    // Delegate instead of the non-standard explicit constructor call.
    : SparseMatrix(copy.head->row, copy.head->col, copy.zero)
{
    // Current tail of every column list; new nodes are appended after it.
    vector<OrthogonalChainsNode<T>*> col_list = col_ptr;
    const size_t rows = head->row;
    for (size_t i = 0; i < rows; ++i)
    {
        OrthogonalChainsNode<T>* const dst_row = row_ptr[i];
        OrthogonalChainsNode<T>* const src_end = copy.row_ptr[i];
        // The node copy constructor copies the payload only; the clone starts unlinked.
        for (OrthogonalChainsNode<T>* run = src_end->row_post; run != src_end; run = run->row_post)
            insert_before_in_row_col_list(dst_row, new OrthogonalChainsNode<T>(*run), col_list);
    }
}

/// Allocates the sentinel plus one empty header per row and per column.
/// Expects row_ptr/col_ptr to already have `row`/`col` slots. `z` is unused here.
template <typename T>
void SparseMatrix<T>::initSparseMatrix(size_t row, size_t col, const T& z)
{
    using Node = OrthogonalChainsNode<T>;

    // Sentinel: stores the dimensions; its column ring will hold the row headers.
    head = new Node();
    head->flag = Node::type::head;
    head->row = row;
    head->col = col;
    head->col_pre = head;
    head->col_post = head;

    // Row headers: each starts as an empty (self-linked) row ring.
    for (size_t r = 0; r < row; ++r)
    {
        Node* const hdr = new Node();
        hdr->flag = Node::type::row_head;
        hdr->row = r + 1;
        hdr->col = 0;
        hdr->row_pre = hdr;
        hdr->row_post = hdr;
        insert_before_col(head, hdr);
        row_ptr[r] = hdr;
    }

    // The sentinel's row ring holds the column headers.
    head->row_pre = head;
    head->row_post = head;

    // Column headers: each starts as an empty (self-linked) column ring.
    for (size_t c = 0; c < col; ++c)
    {
        Node* const hdr = new Node();
        hdr->flag = Node::type::col_head;
        hdr->col = c + 1;
        hdr->row = 0;
        hdr->col_pre = hdr;
        hdr->col_post = hdr;
        insert_before_row(head, hdr);
        col_ptr[c] = hdr;
    }
}

/// Creates an all-zero row x col matrix: only the sentinel and header nodes exist.
template <typename T>
SparseMatrix<T>::SparseMatrix(size_t row, size_t col, const T& z)
    : row_ptr(row), col_ptr(col), zero(z)
{
    // One slot per row/column has been reserved above; fill them with headers.
    this->initSparseMatrix(row, col, z);
}

/// Computes (*this) * be_multiply into `result`.
///
/// Dimensions are not validated: the caller must ensure this->cols equals
/// be_multiply's row count, and that `result` is an empty matrix shaped
/// (rows of *this) x (cols of be_multiply). Each output row is accumulated in
/// a dense buffer, and only entries that differ from `zero` are stored.
template <typename T>
void SparseMatrix<T>::multiply(const SparseMatrix<T>& be_multiply, SparseMatrix<T>& result)
{
    const size_t out_rows = result.head->row;
    const size_t out_cols = result.head->col;
    vector<OrthogonalChainsNode<T>*> tails = result.col_ptr;
    for (size_t i = 0; i < out_rows; ++i)
    {
        OrthogonalChainsNode<T>* const lhs_head = row_ptr[i];
        if (lhs_head->row_post == lhs_head)
            continue;  // an empty row of *this gives an empty row of the product

        vector<T> acc(out_cols, zero);
        bool touched = false;
        for (OrthogonalChainsNode<T>* a = lhs_head->row_post; a != lhs_head; a = a->row_post)
        {
            // a = A(i, k): add A(i, k) * B(k, j) for every stored B(k, j).
            OrthogonalChainsNode<T>* const rhs_head = be_multiply.row_ptr[a->col - 1];
            for (OrthogonalChainsNode<T>* b = rhs_head->row_post; b != rhs_head; b = b->row_post)
            {
                touched = true;
                acc[b->col - 1] += a->value * b->value;
            }
        }
        if (!touched)
            continue;

        for (size_t j = 0; j < out_cols; ++j)
        {
            if (acc[j] != zero)
                insert_before_in_row_col_list(result.row_ptr[i], new OrthogonalChainsNode<T>(i + 1, j + 1, acc[j], OrthogonalChainsNode<T>::type::element), tails);
        }
    }
}

/// Appends `inserted` after the tracked tail of its column, then makes it the new tail.
/// col_list[c] holds the current tail (last linked node) of column c+1.
template <typename T>
void SparseMatrix<T>::insert_in_col_list(OrthogonalChainsNode<T>* inserted, vector<OrthogonalChainsNode<T>*>& col_list)
{
    OrthogonalChainsNode<T>*& tail = col_list[inserted->col - 1];
    insert_after_col(tail, inserted);
    tail = inserted;
}

/// Unlinks `goal` from its row and column rings (without freeing it) and
/// rewinds the tracked tail of goal's column to goal's former predecessor.
template <typename T>
void SparseMatrix<T>::deleteInRowColList(OrthogonalChainsNode<T>* goal, vector<OrthogonalChainsNode<T>*>& col_list)
{
    OrthogonalChainsNode<T>*& tail = col_list[goal->col - 1];
    deleteRowColElement(goal);
    // Unlinking keeps goal's own links, so col_pre still names the former predecessor.
    tail = goal->col_pre;
}

/// Links `inserted` into both rings: in the row ring directly in front of
/// `goal` (passing a row header appends to the end of that row), and in the
/// column ring after the tracked tail of its column.
template <typename T>
void SparseMatrix<T>::insert_before_in_row_col_list(OrthogonalChainsNode<T>* goal, OrthogonalChainsNode<T>* inserted, vector<OrthogonalChainsNode<T>*> &col_list)
{
    insert_before_row(goal, inserted);
    this->insert_in_col_list(inserted, col_list);
}

/// Stores (*this) + be_added into `result`, leaving both operands unchanged.
///
/// `result` must be a freshly constructed, empty matrix with the same
/// dimensions (its head, row_ptr and col_ptr fully initialized). Sums equal to
/// `zero` are not stored.
template <typename T>
void SparseMatrix<T>::addGenNewSparseMatrix(const SparseMatrix<T>& be_added, SparseMatrix<T>& result)
{
    // Per-column tails of `result`; every new node is appended to its column.
    vector<OrthogonalChainsNode<T>*> tails = result.col_ptr;
    const size_t n_rows = head->row;
    for (size_t i = 0; i < n_rows; ++i)
    {
        OrthogonalChainsNode<T>* const out_row = result.row_ptr[i];
        OrthogonalChainsNode<T>* const a_end = row_ptr[i];
        OrthogonalChainsNode<T>* const b_end = be_added.row_ptr[i];
        OrthogonalChainsNode<T>* a = a_end->row_post;
        OrthogonalChainsNode<T>* b = b_end->row_post;

        // Merge the two column-sorted rows; output is emitted in column order.
        while (a != a_end || b != b_end)
        {
            if (b == b_end || (a != a_end && a->col < b->col))
            {
                // Only *this has an entry in this column: copy it.
                insert_before_in_row_col_list(out_row, new OrthogonalChainsNode<T>(*a), tails);
                a = a->row_post;
            }
            else if (a == a_end || b->col < a->col)
            {
                // Only be_added has an entry in this column: copy it.
                insert_before_in_row_col_list(out_row, new OrthogonalChainsNode<T>(*b), tails);
                b = b->row_post;
            }
            else
            {
                // Both have an entry in this column: keep the sum unless it cancels.
                const T sum = a->value + b->value;
                if (sum != zero)
                    insert_before_in_row_col_list(out_row, new OrthogonalChainsNode<T>(i + 1, a->col, sum, OrthogonalChainsNode<T>::type::element), tails);
                a = a->row_post;
                b = b->row_post;
            }
        }
    }
}

/// Adds be_added into *this in place (*this += be_added).
///
/// Both matrices are walked row by row in a merge over their column-sorted
/// row lists. Entries of be_added absent from *this are cloned in; matching
/// entries are summed, and a node whose sum equals `zero` is unlinked and freed.
template <typename T>
void SparseMatrix<T>::add(const SparseMatrix<T>& be_added)  // NOTE: dimensions are not checked; caller must ensure both matrices have the same rows and columns
{
    // col_list[c] tracks the last node of column c+1 in the rows processed so
    // far (initially the column header). New nodes are linked after it.
    vector<OrthogonalChainsNode<T>*> col_list = col_ptr;
    size_t row = head->row;
    for (size_t i = 0; i < row; ++i)
    {
        OrthogonalChainsNode<T>* left = row_ptr[i]->row_post;
        OrthogonalChainsNode<T>* right = be_added.row_ptr[i]->row_post;
        while (left != row_ptr[i] && right != be_added.row_ptr[i])
        {
            if (left->col == right->col)
            {
                T sum = left->value + right->value;
                if (sum == zero)
                {
                    // The entry cancels: unlink it (rewinding its column tail), then free it.
                    deleteInRowColList(left, col_list);
                    OrthogonalChainsNode<T>* t = left;
                    // Unlinking keeps left's own links, so row_post is still the next node.
                    t = t->row_post;
                    delete left;
                    left = t;
                }
                else
                {
                    col_list[left->col - 1] = left;
                    left->value = sum;
                    left = left->row_post;
                }
                right = right->row_post;
            }
            else if (left->col < right->col)
            {
                // Entry only in *this: keep it, but record it as its column's tail.
                col_list[left->col - 1] = left;
                left = left->row_post;
            }
            else
            {
                // Entry only in be_added: clone it in front of `left` to keep the row sorted.
                OrthogonalChainsNode<T>* t = new OrthogonalChainsNode<T>(*right);
                insert_before_in_row_col_list(left, t, col_list);
                right = right->row_post;
            }
        }

        // Remaining entries of *this: only the column tails need updating.
        while (left != row_ptr[i])
        {
            col_list[left->col - 1] = left;
            left = left->row_post;
        }

        // Remaining entries of be_added: `left` is the row header here, so
        // inserting before it appends to the end of the row.
        while (right != be_added.row_ptr[i])
        {
            OrthogonalChainsNode<T>* t = new OrthogonalChainsNode<T>(*right);
            insert_before_in_row_col_list(left, t, col_list);
            right = right->row_post;
        }
    }
}


/// Enumerates every 0/1 assignment of seq[level..] and appends each completed
/// sequence to `input`. A 1 is tried before a 0 at every position, so sequences
/// come out in descending binary order. On return, seq[level..] is all zeros.
void genSeq(vector<int>& seq, vector<vector<int>>& input, size_t level)
{
    if (level == seq.size())
    {
        input.push_back(seq);
        return;
    }
    for (int bit : {1, 0})
    {
        seq[level] = bit;
        genSeq(seq, input, level + 1);
    }
}

/// Returns true when both dense matrices have exactly the same shape and entries.
///
/// The previous element-by-element loop took its bounds from `left` only. It
/// reported matrices as equal when `right` had extra rows or columns, and read
/// out of bounds when `right` was smaller. vector's operator== compares sizes
/// first (row count, then each row's length) and then the elements.
bool equal(const vector<vector<long long>>& left, const vector<vector<long long>>& right)
{
    return left == right;
}

/// Dense reference addition: r[i][j] = left[i][j] + right[i][j].
/// The width is taken from the first row of `left`; `right` and `r` must be
/// at least as large as `left`.
void add(const vector<vector<long long>>& left, const vector<vector<long long>>& right, vector<vector<long long>>& r)
{
    const size_t width = left.empty() ? 0 : left.front().size();
    if (width == 0)
        return;

    for (size_t row = 0; row < left.size(); ++row)
    {
        const vector<long long>& a = left[row];
        const vector<long long>& b = right[row];
        vector<long long>& out = r[row];
        for (size_t col = 0; col < width; ++col)
            out[col] = a[col] + b[col];
    }
}

/// Dense reference product: accumulates left * right into `r`
/// (r[i][k] += left[i][j] * right[j][k]). `r` is expected to be zero-filled
/// by the caller and sized (rows of left) x (cols of right).
void multiply(const vector<vector<long long>>& left, const vector<vector<long long>>& right, vector<vector<long long>>& r)
{
    // Shared dimension of the product, read from the first row of `left`.
    const size_t inner = left.empty() ? 0 : left.front().size();
    if (inner == 0)
        return;
    const size_t width = right.front().size();

    for (size_t i = 0; i < left.size(); ++i)
    {
        for (size_t j = 0; j < inner; ++j)
        {
            for (size_t k = 0; k < width; ++k)
                r[i][k] += left[i][j] * right[j][k];
        }
    }
}

/// Dense reference transpose: r[j][i] = left[i][j]. The width comes from the
/// first row of `left`; `r` must be pre-sized to at least cols x rows.
void transpose(const vector<vector<long long>>& left, vector<vector<long long>>& r)
{
    const size_t width = left.empty() ? 0 : left.front().size();
    for (size_t i = 0; i < left.size(); ++i)
    {
        const vector<long long>& src = left[i];
        for (size_t j = 0; j < width; ++j)
            r[j][i] = src[j];
    }
}

void toMatrix(const SparseMatrix<long long>& test, vector<vector<long long>>& r)
{
    for (size_t i = 0; i < test.head->row; ++i)
    {
        OrthogonalChainsNode<long long>* left = test.row_ptr[i]->row_post;
        while (left != test.row_ptr[i])
        {
            r[i][left->col - 1] = left->value;
            left = left->row_post;
        }
    }
}

int main()
{
    vector<int> row(3);
    vector<int> col(2);
    vector<vector<int>> row_set;
    vector<vector<int>> col_set;
    genSeq(row, row_set, 0);
    genSeq(col, col_set, 0);
    for (size_t i = 0; i < row_set.size(); ++i)
    {
        for (size_t j = 0; j < col_set.size(); ++j)
        {
            vector<vector<long long>> input(3, vector<long long>(2, 0));
            for (size_t k = 0; k < 3; ++k)
            {
                if (row_set[i][k] == 0)
                    continue;
                for (size_t m = 0; m < 2; ++m)
                {
                    if (col_set[j][m] == 0)
                        continue;
                    input[k][m] = 2;
                }
            }

            SparseMatrix<long long> test(input, 0ll);
            for (size_t _i = 0; _i < row_set.size(); ++_i)
            {
                for (size_t _j = 0; _j < col_set.size(); ++_j)
                {
                    vector<vector<long long>> input2(3, vector<long long>(2, 0));
                    for (size_t k = 0; k < 3; ++k)
                    {
                        if (row_set[_i][k] == 0)
                            continue;
                        for (size_t m = 0; m < 2; ++m)
                        {
                            if (col_set[_j][m] == 0)
                                continue;
                            input2[k][m] = -1;
                        }
                    }
                    SparseMatrix<long long> test2(input2, 0ll);
                    SparseMatrix<long long> test5(test2);
                    SparseMatrix<long long> test3(3, 2, 0);
                    test.addGenNewSparseMatrix(test2, test3);
                    vector<vector<long long>> input3(3, vector<long long>(2, 0));
                    toMatrix(test3, input3);
                    vector<vector<long long>> input4(3, vector<long long>(2, 0));
                    add(input, input2, input4);
                    if (equal(input3, input4) == false)
                    {
                        cout << "ERROR:addGenNewSparseMatrix相加结果不正确";
                        exit(-1);
                    }
                    else
                    {
                        cout << "addGenNewSparseMatrix相加结果正确" << endl;
                    }
                    SparseMatrix<long long> test4(test);
                    test4.add(test2);
                    vector<vector<long long>> input5(3, vector<long long>(2, 0));
                    toMatrix(test4, input5);
                    if (equal(input5, input4) == false)
                    {
                        cout << "ERROR:add相加结果不正确";
                        exit(-1);
                    }
                    else
                    {
                        cout << "add相加结果正确" << endl;
                    }
                    SparseMatrix<long long> test6(2, 3, 0);
                    test5.transpose(test6);
                    vector<vector<long long>> input6(2, vector<long long>(3, 0));
                    toMatrix(test6, input6);
                    vector<vector<long long>> input7(2, vector<long long>(3, 0));
                    transpose(input2, input7);
                    if (equal(input6, input7) == false)
                    {
                        cout << "ERROR:transpose带参转置错误";
                        exit(-1);
                    }
                    else
                    {
                        cout << "transpose带参转置正确" << endl;
                    }
                    test5.transpose();
                    vector<vector<long long>> input10(2, vector<long long>(3, 0));
                    toMatrix(test5, input10);
                    if (equal(input10, input7) == false)
                    {
                        cout << "ERROR:transpose无参转置错误";
                        exit(-1);
                    }
                    else
                    {
                        cout << "transpose无参转置正确" << endl;
                    }

                    SparseMatrix<long long> test7(3, 3, 0);
                    test.multiply(test6, test7);
                    vector<vector<long long>> input8(3, vector<long long>(3, 0));
                    toMatrix(test7, input8);
                    vector<vector<long long>> input9(3, vector<long long>(3, 0));
                    multiply(input, input6, input9);
                    if (equal(input8, input9) == false)
                    {
                        cout << "ERROR:相乘错误";
                        exit(-1);
                    }
                    else
                    {
                        cout << "相乘正确" << endl;
                    }
                }
            }
        }
    }
    cout << "恭喜全部正确!" << endl;
    return 0;
}



  • 8
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值