作者 | zhonglihao |
算法名 | 稀疏矩阵乘法 Sparse Matrix Multiplication |
分类 | 数据结构 |
复杂度 | O(n^2) |
形式与数据结构 | C++代码 一维结构体存储 |
特性 | 极简封装 不使用链表 不需要转置 计算过程容易理解 |
具体参考出处 | 《算法导论》(写的不想看) |
备注 |
|
// ConsoleApplication1.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include "stdio.h"
#include "stdlib.h"
//稀疏矩阵存储结构体 第一个元素为矩阵头,包含行列长度,元素总个数
typedef struct
{
int row;
int col;
int element;
}sparse_mat;
void SparseMatrixRectPrint(sparse_mat* s_mat);
void SparseMatrixTriPrint(sparse_mat* s_mat);
sparse_mat* SparseMatrixMul(sparse_mat* s_mat_A, sparse_mat* s_mat_B);
int _tmain(int argc, _TCHAR* argv[])
{
int i, j, k;
const int mat_A_row = 4;
const int mat_A_col = 4;
const int mat_B_row = 4;
const int mat_B_col = 4;
//原矩阵
int mat_A[mat_A_row][mat_A_col] = { 1, 1, 0, 0,
0, 0, 1, 0,
0, 1, 0, 0,
0, 0, 1, 0 };
int mat_B[mat_B_row][mat_B_col] = { 1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1 };
//计算有效元素数量
int mat_A_ele_count = 0;
int mat_B_ele_count = 0;
for (i = 0; i < mat_A_row; i++)
{
for (j = 0; j < mat_A_col; j++)
{
if (mat_A[i][j] != 0) mat_A_ele_count++;
}
}
for (i = 0; i < mat_B_row; i++)
{
for (j = 0; j < mat_B_col; j++)
{
if (mat_B[i][j] != 0) mat_B_ele_count++;
}
}
//动态分配
sparse_mat* sparse_m_A = (sparse_mat*)malloc((mat_A_ele_count + 1)*sizeof(sparse_mat));
sparse_mat* sparse_m_B = (sparse_mat*)malloc((mat_B_ele_count + 1)*sizeof(sparse_mat));
//存入稀疏矩阵信息
sparse_m_A[0].row = mat_A_row;
sparse_m_A[0].col = mat_A_col;
sparse_m_A[0].element = mat_A_ele_count;
sparse_m_B[0].row = mat_B_row;
sparse_m_B[0].col = mat_B_col;
sparse_m_B[0].element = mat_B_ele_count;
for (i = 0, mat_A_ele_count = 0; i < mat_A_row; i++)
{
for (j = 0; j < mat_A_col; j++)
{
if (mat_A[i][j] != 0)
{
mat_A_ele_count++;
sparse_m_A[mat_A_ele_count].element = mat_A[i][j];
sparse_m_A[mat_A_ele_count].row = i;
sparse_m_A[mat_A_ele_count].col = j;
}
}
}
for (i = 0, mat_B_ele_count = 0; i < mat_B_row; i++)
{
for (j = 0; j < mat_B_col; j++)
{
if (mat_B[i][j] != 0)
{
mat_B_ele_count++;
sparse_m_B[mat_B_ele_count].element = mat_B[i][j];
sparse_m_B[mat_B_ele_count].row = i;
sparse_m_B[mat_B_ele_count].col = j;
}
}
}
//打印原数组
SparseMatrixRectPrint(sparse_m_A);
SparseMatrixRectPrint(sparse_m_B);
//SparseMatrixTriPrint(sparse_m_A);
//SparseMatrixTriPrint(sparse_m_B);
//计算稀疏矩阵乘法
sparse_mat* sparse_m_C = (sparse_mat*)SparseMatrixMul(sparse_m_A, sparse_m_B);
SparseMatrixRectPrint(sparse_m_C);
system("Pause");
return 0;
}
//三元组稀疏矩阵乘法函数 极简封装 需要花费一点时间计算申请的内存 但是肯定比链表省空间啦
//Method Written By Zhonglihao
sparse_mat* SparseMatrixMul(sparse_mat* s_mat_A, sparse_mat* s_mat_B)
{
int i, j, k;
int s_mat_C_row = s_mat_A[0].row;
int s_mat_C_col = s_mat_B[0].col;
int s_mat_A_ele_count = s_mat_A[0].element;
int s_mat_B_ele_count = s_mat_B[0].element;
//判断是否能够相乘 或 有一个全为0 那就不用乘啦
if (s_mat_A[0].col != s_mat_B[0].row) return NULL;
if (s_mat_A_ele_count == 0 || s_mat_B_ele_count == 0)
{
sparse_mat* s_mat_C = (sparse_mat*)malloc((1)*sizeof(sparse_mat));
s_mat_C[0].row = s_mat_C_row;
s_mat_C[0].col = s_mat_C_col;
s_mat_C[0].element = 0;
return s_mat_C;
}
//申请一个长度为B列宽的缓存 两个用途 计算输出大小时做列封禁,计算相乘时做和缓存
int* col_buffer = (int*)malloc(s_mat_C_col*sizeof(int));
//清空缓存区
for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;
//判断需要输出的三元大小申请内存
int malloc_element_count = 0;
for (i = 1; i <= s_mat_A_ele_count; i++)
{
if (i >= 2 && s_mat_A[i].row != s_mat_A[i - 1].row) //换行解禁
{
for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;
}
for (j = 1; j <= s_mat_B_ele_count; j++)
{
if ((s_mat_A[i].col == s_mat_B[j].row) && col_buffer[s_mat_B[j].col] != 1)//没有列封禁
{
col_buffer[s_mat_B[j].col] = 1;//列封禁
malloc_element_count++;
}
}
}
sparse_mat* s_mat_C = (sparse_mat*)malloc((malloc_element_count + 1)*sizeof(sparse_mat));
s_mat_C[0].row = s_mat_C_row;
s_mat_C[0].col = s_mat_C_col;
s_mat_C[0].element = malloc_element_count;
int s_mat_C_ele_count = 0;//用于存入元素时做指针
//开始进行乘法相乘
for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;//清理列缓存
for (i = 1; i <= s_mat_A_ele_count; i++)
{
for (j = 1; j <= s_mat_B_ele_count; j++)
{
if (s_mat_A[i].col == s_mat_B[j].row)//有效用 压入缓存区
col_buffer[s_mat_B[j].col] += s_mat_A[i].element * s_mat_B[j].element;
}
//如果要换行或者是最后一行
if (((i != s_mat_A_ele_count) && (s_mat_A[i].row != s_mat_A[i + 1].row)) || i == s_mat_A_ele_count)
{
//扫描缓存组
for (k = 0; k < s_mat_C_col; k++)
{
//如果该点不是0 压入三元组 清零缓存
if (col_buffer[k] != 0)
{
s_mat_C_ele_count++;
s_mat_C[s_mat_C_ele_count].row = s_mat_A[i].row;
s_mat_C[s_mat_C_ele_count].col = k;
s_mat_C[s_mat_C_ele_count].element = col_buffer[k];
col_buffer[k] = 0;
}
}
}
}
//释放缓存 返回结果
free(col_buffer);
return s_mat_C;
}
//稀疏矩阵打印 按矩形打印 需要确定三元组按Z排列有序
void SparseMatrixRectPrint(sparse_mat* s_mat)
{
//获取行列信息
int i, j;
int row = s_mat[0].row;
int col = s_mat[0].col;
//打印元素递增 前提是三元组按照行列顺序排好,就只需要递增下标
int ele_count = 1;
//按矩阵扫描打印
for (i = 0; i < row; i++)
{
for (j = 0; j < col; j++)
{
if (i == s_mat[ele_count].row && j == s_mat[ele_count].col)
{
printf("%d\t", s_mat[ele_count].element);
ele_count++;
}
else
{
printf("0\t");
}
}//for
printf("\n");
}//for
//跳空换行 返回
printf("\n");
return;
}
//稀疏矩阵打印 按三元组结构打印
void SparseMatrixTriPrint(sparse_mat* s_mat)
{
int i, j;
int ele_count = s_mat[0].element;
//按顺序打印
for (i = 1; i <= ele_count; i++)
{
printf("%d\t%d\t%d\n", s_mat[i].row, s_mat[i].col, s_mat[i].element);
}
//跳空换行 返回
printf("\n");
return;
}