C++ 方矩阵乘法 + Strassen矩阵

这篇博客介绍了如何在C++中实现矩阵乘法,包括传统的n^3复杂度算法和Strassen算法。Strassen算法通过分治策略将复杂度降低到n^lg7,在矩阵规模较大时体现出优势。文章详细阐述了矩阵类的设计,如使用Matrix和MatrixRef,并讨论了处理非2的幂次矩阵的方法。尽管在小规模矩阵运算中传统算法更快,但Strassen算法在大规模矩阵乘法时表现更优,并且在300*300以上的尺寸可能导致栈溢出问题。
摘要由CSDN通过智能技术生成

  这几天看算法导论,看到矩阵一章,就实现了一下。

下面是普通的矩阵乘法,复杂度为:n^3。

template<unsigned M,unsigned N, unsigned Q>
void Square_matrix_multiply(int(&A)[M][N], int(&B)[N][Q], int(&C)[M][Q]) {                 
	for (size_t i = 0;i != M;++i) {
		for (size_t j = 0;j != Q;++j) {
			C[i][j] = 0;
			for (size_t n = 0;n != N;++n) {
				C[i][j] += A[i][n] * B[n][j];
			}
		}
	}
}

函数接受三个二维数组,A * B得到的矩阵赋值给C。

下面是分治策略的算法。

template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1)
		return C = A.get()*B.get();
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);    // 使用一个类MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);    // 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);    // 进行分割
		C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);  // Matrix::operator+;
		C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);  // MatrixRef::operator=;
		C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
		C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
	}
	return C;
}

矩阵实现了一个Matrix类(具体实现在最下面),有一个构造函数:接受两个size_t值l、r,生成l*r大小值全为0的矩阵。

Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
	data->resize(l*r);
}

其中hight为矩阵行高,width为列宽,data为shared_ptr,矩阵用vector实现。

A.rows()返回A的width长度(即方矩阵的边长),Matrix(n,n)创建一个矩阵。

size_t rows() const {
		return width;
	}

如果n==1,通过Matrix的get函数返回第一个元素,也就是唯一的一个元素。

int Matrix::get() const {
	return (*data)[0];
}

为了不复制矩阵元素(如果可以复制矩阵元素的话,会简单很多),另实现了一个MatrixRef,其含有:两个size_t数据成员(实现坐标点)、一个size_t数据成员(实现矩阵长度)、一个weak_ptr(指向vector<int>)。

MatrixRef含有两个构造函数:一个接受Matrix加两个size_t;一个接受MatrixRef加两个size_t。都是为了指明引用范围。

MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), 
      hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr),                 
      hight_startptr(mref.hight_startptr + line),
      width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }

wptr用data或wptr初始化,避免拷贝。length为rows()的返回值除以2,因为是分割为4个矩阵,行列各除以2。

要注意:接受MatrixRef的坐标要加上之前的坐标。

MatrixRef也有一个rows成员函数,为了递归调用。

size_t rows() const {
		return length;
	}

Square_max_matrix_multiply_recursive函数返回一个Matrix,Matrix实现了operator+,但是行列必须相等。

Matrix& Matrix::operator+=(const Matrix &rhs) {
	if (hight == rhs.hight && width == rhs.width) {
		for (size_t i = 0;i != size();++i)
			(*data)[i] += (*rhs.data)[i];
	}
	else
		throw std::logic_error("Not Matched");
	return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}

MatrixRef实现了一个operato

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值