稀疏矩阵乘法与加法


稀疏矩阵的三元组储存,模式来自于严蔚敏的数据结构教材

支持乱序输入


<pre name="code" class="cpp">//      whn6325689
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>


using namespace std;

typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;

#define CLR(x,y) memset(x,y,sizeof(x))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define eps 1e-9
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62

template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    char c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
    n = x*tmp;
    return true;
}
template <class T>
inline void write(T n)
{
    if(n < 0)
    {
        putchar('-');
        n = -n;
    }
    int len = 0,data[20];
    while(n)
    {
        data[len++] = n%10;
        n /= 10;
    }
    if(!len) data[len++] = 0;
    while(len--) putchar(data[len]+48);
}
//-----------------------------------

#define OK 1
#define ERROR 0
#define MAXSIZE 100
#define MAXRC 10


struct Node
{
    int x,y;
    int e ;
};

struct Matrix
{
    Node data[MAXSIZE+1];
    int rpos[MAXRC+1];
    int mu,nu,tu;
};

bool cmp(Node a,Node b)
{
    if(a.x==b.x)
        return a.y<b.y;
    return a.x<b.x;
}

int create(Matrix *matrix)
//     创建一个稀疏矩阵;
//     输入行数、列数,支持乱序输入三元组,并计数;
//     以行为主序进行重新排列,并记录每行起始位置于matrix->rpos[row];
//     若非零元超过 MAXSIZE或行数超过MAXRC,则返回ERROR,否则OK;
{
    int num=0,p,q,min,temp;             //     中间变量;
    int row;
    printf("input the total row and col:\n");
    read(matrix->mu),read(matrix->nu);      //     输入行数、列数;
    if(matrix->mu > MAXRC)
        return ERROR;
    printf("row col val:(end with zero)\n");
    read(matrix->data[num+1].x),read(matrix->data[num+1].y),read(matrix->data[num+1].e);
    while(matrix->data[num+1].x)      //     乱序输入三元组;
    {
        if(++num>MAXSIZE)
            return ERROR;
        read(matrix->data[num+1].x),read(matrix->data[num+1].y),read(matrix->data[num+1].e);
    }
    matrix->tu=num;                  //     num的值即为此矩阵的非零元个数;
    sort(matrix->data+1,matrix->data+matrix->tu+1,cmp);
    /*
    for(p=1; p<=matrix->tu-1; ++p)          //     按行为主序依次重新排列非零元
    {
        min=p;           //     使较小的行数、列数的元的序号min为当前值p;
        for(q=p+1; q<=matrix->tu; ++q)         //     开始依次比较;
        {
            if(matrix->data[min].x>matrix->data[q].x||(matrix->data[min].x==matrix->data[q].x&&matrix->data[min].y>matrix->data[q].y))
                min=q;           //     在乱序的三元表中,始终保证min是较小的行列数的序号;
        }
        temp=matrix->data[min].x;                         //     交换行值;
        matrix->data[min].x=matrix->data[p].x;
        matrix->data[p].x=temp;
        temp=matrix->data[min].y;                         //     交换列值;
        matrix->data[min].y=matrix->data[p].y;
        matrix->data[p].y=temp;
        temp=matrix->data[min].e;                        //     交换元素值;
        matrix->data[min].e=matrix->data[p].e;
        matrix->data[p].e=temp;
    }
    */
    for(row=1,num=0; row<=matrix->mu; ++row)          //     记录matrix->rpos[row];
    {
        matrix->rpos[row]=num+1;        //    逆记录
        while(matrix->data[num+1].x==row)
            ++num;
    }
    return OK;
}


void print(Matrix *matrix)
//     输入矩阵,打印出矩阵的行数、列数、非零元个数,以及整个矩阵;
{
    int row,col;
    int num=0;
    printf("\nrow:%d col:%d number:%d\n",matrix->mu,matrix->nu,matrix->tu);
    for(row=1; row<=matrix->mu; ++row)
    {
        for(col=1; col<=matrix->nu; ++col)
        {
            if(num+1<=matrix->tu&&matrix->data[num+1].x==row&&matrix->data[num+1].y==col)
            {
                ++num;
                printf("%4d",matrix->data[num].e);            /*    当扫描到非零元的行列值与之相等时,输出其值    */
            }
            else
                printf("%4d",0);           /*    没有非零元的地方补0 */
        }
        printf("\n");           /*    每行输入完毕后,换行         */
    }
}




Matrix* mult(Matrix *A,Matrix *B)
//     输入两个稀疏矩阵M和N,并初始化Q,然后计算M*N的值赋给Q;
//     如果M->mu!=N->nu或列数大于MAXRC或者计算出的非零元个数大于MAXSIZE,都返回ERROR,否则OK;
//     计算过程如下:
//     1.    由于矩阵M和Q的行数相等并且C语言以行为主序进行存储,所以以M进行逐行的扫描。
//     2.    使Q的此行逻辑表的序号等于其非零元个数Q.tu+1,以表示其行的首个元素的序号。
//     3.    从行中找到M的非零元,并以它的列值为N的行号,对N进行行的扫描,若存在,则依次计算它们,并把其值累加到一个以N中这个对应非零元的列值为序号的临时数组ctemp[ccol]中。
//     4.    在M的当前行完成扫描后,将ctemp[ccol]不为0的值,压入到Q矩阵的三元组,累加++Q.tu,若Q.tu大于了MAXSIZE,这返回ERROR。

{
    Matrix *Q;
    if(!(Q=(Matrix *)malloc(sizeof(Matrix))))
        exit(ERROR);
    int arow,brow,ccol;
    int * ctemp;    /*    以N的列值为序号的临时数组   */
    int tp,p,tq,q;          /*    中间变量       */
    if(A->nu!=B->mu)
        return NULL;
    Q->mu=A->mu;           /*    初始化Q       */
    Q->nu=B->nu;
    Q->tu=0;
    if(!(ctemp=(int *)malloc((B->nu+1)*sizeof(int))))     /*    动态建立累加器    */
        exit(ERROR);
    if(A->tu*B->tu!=0)             /*    Q是非零矩阵       */
    {
        for(arow=1; arow<=A->mu; ++arow)  /*    逐行扫描       */
        {
            for(ccol=1; ccol<=B->nu; ++ccol)
                ctemp[ccol]=0;              /*    初始化累加器       */
            Q->rpos[arow]=Q->tu+1;
            if(arow<A->mu)
                tp=A->rpos[arow+1];    /*    tp是M下一行的序号   */
            else
                tp=A->tu+1;
            for(p=A->rpos[arow]; p<tp; ++p) /*    从M的当前行找到元素      */
            {
                brow=A->data[p].y;              /*    对应元在N中的行号   */
                if(brow<B->mu)
                    tq=B->rpos[brow+1];    /*    tq是N下一行的行号   */
                else
                    tq=B->tu+1;
                for(q=B->rpos[brow]; q<tq; ++q) /*    以M的对应元的列号为N的行号进行扫描      */
                {
                    ccol=B->data[q].y;         /*    提取对应元的列号       */
                    ctemp[ccol]+=A->data[p].e*B->data[q].e;
                    /*    两个对应元的值相乘并累加到以列号为序号的累加器中       */
                }
            }
            for(ccol=1; ccol<=Q->nu; ++ccol) /*    将此行非零元压缩入Q中   */
            {
                if(ctemp[ccol])
                {
                    if(++Q->tu>MAXSIZE)
                        return NULL;
                    Q->data[Q->tu].x=arow;
                    Q->data[Q->tu].y=ccol;
                    Q->data[Q->tu].e=ctemp[ccol];
                }
            }
        }
    }
    return Q;
}

Matrix* add(Matrix *A,Matrix *B)     //两稀疏矩阵相加
//     输入两个稀疏矩阵M和N,并初始化Q,然后计算M+N的值赋给Q;
//     如果M->mu!=N->nu或列数大于MAXRC或者计算出的非零元个数大于MAXSIZE,都返回ERROR,否则OK;
{
    Matrix *Q;
    if(!(Q=(Matrix *)malloc(sizeof(Matrix))))
        exit(ERROR);
    Q->mu=A->mu;
    Q->nu=A->nu;
    Q->tu=0;
    int a,b,k;
    a=b=1;
    for(k=1; k<=A->mu; k++)
    {
        while(A->data[a].x==k && B->data[b].x==k)     //两结点属于同一行
        {
            if(A->data[a].y<B->data[b].y) //A中下标a的结点小于B中下标为b的结点的列数
            {
                if(++Q->tu>MAXSIZE)
                    return NULL;
                Q->data[Q->tu].e=A->data[a].e;
                Q->data[Q->tu].x=A->data[a].x;
                Q->data[Q->tu].y=A->data[a].y;
                a++;
            }
            else if(A->data[a].y==B->data[b].y)     //两结点属于同一列
            {
                int temp=A->data[a].e+B->data[b].e;
                if(temp==0)         //两结点数值相加等于0
                {
                    a++;
                    b++;
                }
                else        //两结点数值相加不等于0
                {
                    if(++Q->tu>MAXSIZE)
                        return NULL;
                    Q->data[Q->tu].e=temp;
                    Q->data[Q->tu].x=A->data[a].x;
                    Q->data[Q->tu].y=A->data[a].y;
                    a++;
                    b++;
                }
            }
            else           //A中下标a的结点大于B中下标为b的结点的列数
            {
                if(++Q->tu>MAXSIZE)
                    return NULL;
                Q->data[Q->tu].e=B->data[b].e;
                Q->data[Q->tu].x=B->data[b].x;
                Q->data[Q->tu].y=B->data[b].y;
                b++;
            }
        }
        while(A->data[a].x==k)      //只有A中剩下行数为k的未处理结点
        {
            if(++Q->tu>MAXSIZE)
                return NULL;
            Q->data[Q->tu].e=A->data[a].e;
            Q->data[Q->tu].x=A->data[a].x;
            Q->data[Q->tu].y=A->data[a].y;
            a++;
        }
        while(B->data[b].x==k)     //只有B中剩下行数为k的未处理结点
        {
            if(++Q->tu>MAXSIZE)
                return NULL;
            Q->data[Q->tu].e=B->data[b].e;
            Q->data[Q->tu].x=B->data[b].x;
            Q->data[Q->tu].y=B->data[b].y;
            b++;
        }
    }
    return Q;
}



int main()
{
    freopen("data.txt","r",stdin);
    Matrix * M,* N,* Q;
    if(!(M=(Matrix *)malloc(sizeof(Matrix))))
        exit(ERROR);
    if(!(N=(Matrix *)malloc(sizeof(Matrix))))
        exit(ERROR);
    if(create(M)&&create(N))
    {
        printf("\nput out M:\n");
        print(M);           /*    打印出M       */
        printf("\nput out N:\n");
        print(N);            /*    打印出N       */
        if((Q=mult(M,N)))
        {
            printf("\n M  *  N  :\n");
            print(Q);     /*    计算结果       */
            free(Q);
        }
        else
            printf("M.mu and N.nu are not mathing\n");
        if((Q=add(M,N)))
        {
            printf("\n M  +  N  :\n");
            print(Q);     /*    计算结果       */
            free(Q);
        }
        else
        {
            printf("M.mu and N.nu are not mathing\n");
        }
    }
    else
        printf("input error.\n");
    return 0;
}


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值