C++元编程实现简单矩阵运算

目录

基础数据类型实现

基础算法实现

基础遍历算法实现

用户函数实现

矩阵算法试用


运行时编程通过简单的循环可以通过一个简单的函数实现矩阵的运算,那么使用元编程实现矩阵乘法的意义何在呢?元编程如果强行要说有什么优势的话就是其运行效率可能会有所提高、代码意义明确(至少对于掌握元编程的人来说是这样的)且无需很多异常判断,当然代价是程序实现的困难、执行程序的扩大以及编译时间的延长。元编程最大的好处是可以在编译期判断出固定大小矩阵计算的错误,减少BUG在运行期出现的风险。总之这种实现方法也未尝不可一试。

基础数据类型实现

首先实现矩阵类模板,这个很简单,就是矩阵数据的设置、存取。

/* 内存管理 */
template<typename val_t>
struct mat_m 
{
	val_t* p;
	mat_m(const int& sz):p(nullptr)
	{
		p = (val_t*)malloc(sz* sizeof(val_t));
	}
	~mat_m() 
	{
		if (p)
			free(p);
	}
	val_t& get(const int& len_1d, const int& i_1d_idx, const int& i_2d_idx) 
	{
		return p[i_2d_idx + len_1d * i_1d_idx];
	}
};
/* 矩阵模板 */
template<int row_num, int col_num, typename val_t = double>
struct mat
{
	typedef val_t vt;
	static constexpr int r = row_num;
	static constexpr int c = col_num;
	std::shared_ptr<mat_m<val_t>> pval;
	mat()
	{
		pval = std::make_shared<mat_m<val_t>>(row_num * col_num);
	}
	mat(const mat<row_num, col_num, val_t>& other) :pval(other.pval)
	{
	}
	mat(const val_t&& v)
	{
		pval = std::make_shared<mat_m<val_t>>(row_num * col_num);
		for (int i = 0; i < row_num; ++i) 
		{
			for (int j = 0; j < col_num; ++j) 
			{
				pval->get(col_num, i,j) = v;
			}
		}
	}
	mat(const std::initializer_list<val_t>& lst)
	{
		pval = std::make_shared<mat_m<val_t>>(row_num * col_num);
		auto itr = lst.begin();
		for (int i = 0 ; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				pval->get(col_num, i, j) = *itr;
				itr++;
				if (itr == lst.end())return;
			}
		}
	}
	
	val_t& get(const int& i_row, const int& i_col)
	{
		return pval->get(col_num, i_row, i_col);
	}

	val_t get(const int& i_row, const int& i_col) const
	{
		return pval->get(col_num, i_row, i_col);
	}
	
	mat<col_num, row_num, val_t> t() 
	{
		mat<col_num, row_num, val_t> ret;
		ret.pval = pval;
		return ret;
	}

	void print() 
	{
		std::cout << "[" << std::endl;
		for (int i = 0; i < row_num; ++i) 
		{
			std::cout << std::setw(3) << "[";
			for (int j = 0; j < col_num; ++j) 
			{
				std::cout << std::setw(6) << pval->get(col_num, i, j);
			}
			std::cout << std::setw(3) << "]" << std::endl;
		}
		std::cout << "]" << std::endl;
	}
};

基础算法实现

我们对于点乘的算法都清楚,要获取目标矩阵的一个元素,首先根据元素所在行获取左矩阵的行向量,再根据元素列获取右矩阵的列向量。对行列元素进行同步迭代。对每对元素相乘并加和。众所周知,元编程中的循环是通过递归实现的。那么我们就使用递归实现元素的乘积和,然后用递归实现对应行列的同步迭代,代码如下(这里使用了if constexpr是c++17才具备的特性,如果版本不够可以使用模板特化来实现,当然,代码会看上去更乱一些):

template<int r1, int c1, int r2, int c2, typename imatt1, typename imatt2, typename vt = double>
inline vt n_dot(const imatt1& mt1, const imatt2& mt2)
{
	static_assert(c1 == r2, "[matrix dot error]\tleft matrix column number do not match right matrix's row number.");
	if constexpr (c1 != 0 || r2 != 0)
	{
		return mt1.get(r1, c1) * mt2.get(r2, c2) + n_dot<r1, c1 - 1, r2 - 1, c2>(mt1, mt2);
	}
	if constexpr (c1 == 0 && r2 == 0)
	{
		return mt1.get(r1, c1) * mt2.get(r2, c2);
	}
}

template<int r, int c, typename imatt1, typename imatt2, typename vt = double>
inline vt v_dot(const imatt1& mt1, const imatt2& mt2) 
{
	return n_dot<r, mt1.c - 1, mt2.r - 1, c>(mt1, mt2);
}

上述代码通过调用v_dot可以获取目标矩阵[r,c]位置的值。我们还需要对目标矩阵的行列依次迭代以获取每个位置的值。同理,我们需要使用递归来实现行的迭代,然后再使用递归实现列的迭代(由此可见,能够熟练使用递归算法对于元编程是一种核心技能)。首先先实现一些基础运算,用于计算输出矩阵对应位置的值,代码如下:


template<int r1, int c1, int r2, int c2, typename imatt1, typename imatt2, typename vt = double>
inline vt n_dot(const imatt1& mt1, const imatt2& mt2)
{
	static_assert(c1 == r2, "[matrix dot error]\tleft matrix column number do not match right matrix's row number.");
	if constexpr (c1 != 0 || r2 != 0)
	{
		return mt1.get(r1, c1) * mt2.get(r2, c2) + n_dot<r1, c1 - 1, r2 - 1, c2>(mt1, mt2);
	}
	if constexpr (c1 == 0 && r2 == 0)
	{
		return mt1.get(r1, c1) * mt2.get(r2, c2);
	}
}

template<int r, int c>
class v_dot
{
public:
	template<typename imatt1, typename imatt2, typename vt = double>
	static vt cal(const imatt1& mt1, const imatt2& mt2)
	{
		return n_dot<r, mt1.c - 1, mt2.r - 1, c>(mt1, mt2);
	}
};

template<int r, int c>
class v_add
{
public:
	template<typename imatt1, typename imatt2, typename vt = double>
	static vt cal(const imatt1& mt1, const imatt2& mt2)
	{
		return mt1.get(r, c) + mt2.get(r, c);
	}
};

template<int r, int c>
class n_add
{
public:
	template<typename imatt2, typename vt = double>
	static vt cal(const vt& mt1, const imatt2& mt2)
	{
		return mt1 + mt2.get(r, c);
	}
};

template<int r, int c>
class n_mul
{
public:
	template<typename imatt2, typename vt = double>
	static vt cal(const vt& mt1, const imatt2& mt2)
	{
		return mt1 * mt2.get(r, c);
	}
};

template<typename val_t = double>
val_t sigmoid(const val_t& v)
{
	return 1. / (1. + exp(-1. * v));
}

template<int r, int c>
class sigmoidc
{
public:
	template<typename imatt, typename vt = double>
	static vt cal(const imatt& mt) 
	{
		return sigmoid(mt.get(r, c));
	}
};

基础遍历算法实现

接下来,实现一下对输出矩阵的遍历操作,代码如下:

template<int r, int c, template<int, int> class op, typename omatt, typename... imatts>
inline void row_loop(omatt& omt, const imatts&...imts)
{
	omt.get(r, c) = op<r, c>::cal(imts...);
	if constexpr (r != 0) 
	{
		row_loop<r - 1, c, op>(omt, imts...);
	}
}


template<int c, template<int, int> class op, typename omatt, typename...imatts>
inline void col_loop(omatt& omt, const imatts&...imts)
{
	row_loop<omt.r - 1, c, op>(omt, imts...);
	if constexpr (c != 0)
	{
		col_loop<c - 1, op>(omt, imts...);
	}
}

用户函数实现

以上都是内部函数,用于实现内部逻辑,接下来实现的是各个具体被使用的函数。

template<int row_num1, int col_num1, int row_num2, int col_num2, typename val_t = double>
mat<row_num1, col_num2, val_t> dot(const mat<row_num1, col_num1, val_t>& mt1, const mat<row_num2, col_num2, val_t>& mt2)
{
	
	using omatt = mat<row_num1, col_num2, val_t>;
	using imatt1 = mat<row_num1, col_num1, val_t>;
	using imatt2 = mat<row_num2, col_num2, val_t>;
	omatt mt_ret;
	col_loop<col_num2 - 1, v_dot>(mt_ret, mt1, mt2);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> add(const mat<row_num, col_num, val_t>& mt1, const mat<row_num, col_num, val_t>& mt2) 
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<mt2.c - 1, v_add>(mt_ret, mt1, mt2);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> add(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<col_num - 1, n_add>(mt_ret, v, mt);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> add(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<col_num - 1, n_add>(mt_ret, v, mt);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> mul_n(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<col_num - 1, n_mul>(mt_ret, v, mt);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> mul_n(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<col_num - 1, n_mul>(mt_ret, v, mt);
	return mt_ret;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> sigmoidm(const mat<row_num, col_num, val_t>& mt)
{
	using omatt = mat<row_num, col_num, val_t>;
	omatt mt_ret;
	col_loop<col_num - 1, sigmoidc>(mt_ret, mt);
	return mt_ret;
}

矩阵算法试用

至此我们已经实现了矩阵的乘法运算。实际上的核心算法代码是数个函数构成,每个函数都意义明确。下面我们来试用一下这个元编程矩阵运算。试验代码如下:

int main(int argc, char** argv)
{
	mat<3, 2> mt1 = { 1.1,2.2,3.,4.,5.,6.,7.,8.,9. };
	mat<2, 3> mt2({ 9., 8., 7., 6., 5., 4., 3., 2., 1. });

	auto mt3 = mt2.t();               // 矩阵转置

	mt1.print();
	mt2.print();
	mt3.print();
	
	auto omt = dot(mt1, mt2);        // 矩阵点乘
	omt.print();

	auto omt2 = add(mt1, mt3);        // 矩阵加法
	omt2.print();

	auto omt3 = add(1.0, omt2);        // 常数加矩阵
	omt3.print();
	auto omt4 = add(omt2, 1.0);        // 常数加矩阵
	omt4.print();

	auto omt5 = mul_n(omt2, 2.0);    // 常数乘矩阵
	omt5.print();
	auto omt6 = mul_n(0.1, omt2);    // 常数乘矩阵
	omt6.print();

	auto omt7 = sigmoidm(omt6);        // 矩阵各元素sigmoid运算推广
	omt7.print();

    return 0;
}

结果表明,这个是可以正确运算出结果的。以下是运行结果:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

腾昵猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值