CSR和COO实现spgemm

1) COO实现spgemm 

#include <iostream>
#include <vector>
#include <unordered_map>

// 定义稀疏矩阵的元素
struct Element {
    int row, col;
    double value;
};

// 稀疏矩阵乘法函数
std::vector<Element> spgemm(const std::vector<Element>& A, const std::vector<Element>& B, int rowsA, int colsB) {
    //创建可自由缩放的二维矩阵数组
    std::unordered_map<int, std::unordered_map<int, double>> result;

    std::unordered_map<int, std::vector<Element>> B_col;
    for (const auto& elem : B) {
        B_col[elem.col].push_back(elem);// 将B矩阵转换为列主序, 为下面的判断做准备
    }

    // 进行乘法运算
    for (const auto& a : A) {
        //判断矩阵A的列在矩阵B中是否有对应的行相匹配
        if (B_col.find(a.col) != B_col.end()) {
            //a是向量A中的一个数据类型为element的元素, a.col代表元素element.col
            //B_col[a.col]: 确保矩阵B的行与矩阵A的列相匹配
            for (const auto& b : for (const auto& b : B_col[a.col]) {
                result[a.row][b.col] += a.value * b.value;
            }) {
                result[a.row][b.col] += a.value * b.value;
            }
        }
    }

    // 将结果转换为COO格式
    std::vector<Element> C;
    for (const auto& row : result) {
        for (const auto& col : row.second) {
            C.push_back({row.first, col.first, col.second});
        }
    }
    return C;
}

int main() {
    // 示例稀疏矩阵A和B
    std::vector<Element> A = {{0, 0, 1.0}, {0, 1, 2.0}, {1, 0, 3.0}};
    std::vector<Element> B = {{0, 0, 4.0}, {1, 0, 5.0}, {1, 1, 6.0}};

    // 计算A * B
    std::vector<Element> C = spgemm(A, B, 2, 2);

    // 输出结果
    for (const auto& elem : C) {
        std::cout << "C(" << elem.row << ", " << elem.col << ") = " << elem.value << std::endl;
    }

    return 0;
}


2) CSR实现spgemm

#include <iostream>
#include <vector>
#include <unordered_map>
#include <stdexcept>
using namespace std;

// CSR格式的稀疏矩阵
struct CSRMatrix {
	std::vector<int> row_ptr;  // 每行的起始索引
	std::vector<int> col_idx;  // 非零元素的列索引
	std::vector<double> values;  // 非零元素的值
	int rows, cols;  // 矩阵的行数和列数
};

// SpGEMM算法实现
CSRMatrix spgemm(const CSRMatrix& A, const CSRMatrix& B) {
	// 检查矩阵整体维度是否匹配
	if (A.cols != B.rows) {
		throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
	}

	CSRMatrix C;
	//初始化C矩阵的rows和cols
	C.rows = A.rows;
	C.cols = B.cols;
	//初始化c.row_ptr, 并将元素值都设为0
	C.row_ptr.resize(C.rows + 1, 0);

	// //定义元素个数为C.rows, 元素类型为字典的向量, 临时存储每行的非零元素
	// 实现了可以自动扩展大小的二级数组的效果
	std::vector<std::unordered_map<int, double>> temp(C.rows);

	// 遍历矩阵A的每一行
	for (int i = 0; i < A.rows; ++i) {
		// 遍历A的当前行的每一个非零元素
		//由于CSR存储格式的特征, row_ptr[i+1]-row_ptr[i]代表第i行元素的个数
		for (int j = A.row_ptr[i]; j < A.row_ptr[i + 1]; ++j) {
			int a_col = A.col_idx[j];  // A的列索引
			double a_val = A.values[j];  // A的值
			//B.row_ptr[a_col]: 找到与A的列对应的矩阵B的行
			for (int k = B.row_ptr[a_col]; k < B.row_ptr[a_col + 1]; ++k) {
				int b_col = B.col_idx[k];  // B的列索引, 与矩阵A中某一行相乘的矩阵B中的某一个列号
				double b_val = B.values[k];  // B的值
				temp[i][b_col] += a_val * b_val;  // 累加结果
			}
		}
	}

	// 将临时存储的结果转换为CSR格式
	for (int i = 0; i < C.rows; ++i) {
		for (const auto& pair : temp[i]) {
			C.col_idx.push_back(pair.first);  // 列索引
			C.values.push_back(pair.second);  // 值
		}
		C.row_ptr[i + 1] = C.col_idx.size();  // 更新行指针
	}

	return C;
}

int main() {
	
	// 示例矩阵A的初始化
	CSRMatrix A = {
		{0, 2, 4},  // 行指针
		{0, 1, 0, 2},  // 列索引
		{1.0, 2.0, 3.0, 4.0},  // 值
		2, 3  // 行数和列数
	};

	// 示例矩阵B的初始化
	CSRMatrix B = {
		{0, 1, 3, 4},  // 行指针
		{0, 1, 2, 2},  // 列索引
		{5.0, 6.0, 7.0, 8.0},  // 值
		3, 3  // 行数和列数
	};

	// 计算矩阵C = A * B
	CSRMatrix C = spgemm(A, B);

	// 输出结果矩阵C
	std::cout << "C.row_ptr: ";
	for (int val : C.row_ptr) std::cout << val << " ";
	std::cout << "\nC.col_idx: ";
	for (int val : C.col_idx) std::cout << val << " ";
	std::cout << "\nC.values: ";
	for (double val : C.values) std::cout << val << " ";
	std::cout << std::endl;
	return 0;
}

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值