C++元编程——BP神经网络编译期优化

通过C++元编程技术,将BP神经网络中矩阵get函数的计算过程移至编译期,以提高矩阵运算速度并实现编译期越界检查。实测能提升约40%的运行效率。
摘要由CSDN通过智能技术生成

之前的BP神经网络中,矩阵的get函数要用乘法和加法计算出数据所处位置,而实际上,大多数情况下程序获取的是固定位置的数据,所以可以想到使用编译期的计算方法,这样在进行矩阵运算时候可以稍微加快运算速度,而且还可以进行编译期越界检测。word is cheap,show me the code。

首先,测试函数和上一篇博客的一样:https://blog.csdn.net/Dr_Jack/article/details/127776013

运行速度经过测试大概能提升40%。下面展示mat.hpp代码:

#ifndef _MAT_HPP_
#define _MAT_HPP_
#include <climits>

#include <map>
#include <boost/pool/pool.hpp>

template<int i_size, typename val_t>
struct mat_m
{
	static boost::pool<> s_pool;
	val_t* p;
	mat_m() :p(nullptr)
	{
		//p = (val_t*)malloc(sz * sizeof(val_t));
		p = (val_t*)(s_pool.malloc());
		for (int i = 0; i < i_size; ++i)
		{
			p[i] = 0;
		}
	}
	~mat_m()
	{
		if (p)
		{
			//free(p);
			s_pool.free(p);
		}
	}
	val_t& get(const int& len_1d, const int& i_1d_idx, const int& i_2d_idx)
	{
		val_t& ret = p[i_2d_idx + len_1d * i_1d_idx];
		if (ret != 0.000 && abs(ret) < (DBL_MIN))
		{
			p[i_2d_idx + len_1d * i_1d_idx] = (DBL_MIN * (ret < 0 ? -1. : 1.));
		}
		return ret;
	}

	val_t max_abs() const
	{
		double d = -1*DBL_MAX;
		for (int i = 0; i < i_size; ++i) 
		{
			d = d < abs(p[i]) ? abs(p[i]) : d;
		}
		return d;
	}

	template<int len_1d, int i_1d_idx, int i_2d_idx>
	inline val_t& get_val()
	{
		static_assert((i_2d_idx + len_1d * i_1d_idx) < i_size, "ERROR:mat_m over flow!!!");
		return p[i_2d_idx + len_1d * i_1d_idx];
	}

	template<int len_1d, int i_1d_idx, int i_2d_idx>
	inline val_t get_val() const
	{
		return p[i_2d_idx + len_1d * i_1d_idx];
	}
};

template<int i_size, typename val_t>
boost::pool<> mat_m<i_size, val_t>::s_pool = boost::pool<>(i_size * sizeof(val_t));

template<int row_num, int col_num, typename val_t = double>
struct mat
{
	typedef val_t vt;
	static constexpr int r = row_num;
	static constexpr int c = col_num;
	using mat_m_t = mat_m<row_num * col_num, val_t>;
	std::shared_ptr<mat_m_t> pval;
	bool b_t;
	mat():b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
	}
	mat(const mat<row_num, col_num, val_t>& other) :pval(other.pval), b_t(other.b_t)
	{
	}
	mat(const val_t&& v):b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				pval->get(col_num, i, j) = v;
			}
		}
	}
	mat(const std::initializer_list<val_t>& lst):b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
		auto itr = lst.begin();
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				pval->get(col_num, i, j) = *itr;
				itr++;
				if (itr == lst.end())return;
			}
		}
	}

	val_t& get(const int& i_row, const int& i_col)
	{
		if (!b_t)
			return pval->get(col_num, i_row, i_col);
		else
			return pval->get(row_num, i_col, i_row);
	}

	val_t get(const int& i_row, const int& i_col) const
	{
		if (!b_t)
			return pval->get(col_num, i_row, i_col);
		else
			return pval->get(row_num, i_col, i_row);
	}

	template<int i_1d_idx, int i_2d_idx>
	inline val_t& get_val()
	{
		if (!b_t)
			return pval->get_val<col_num, i_1d_idx, i_2d_idx>();
		else
			return pval->get_val<row_num, i_2d_idx, i_1d_idx>();
	}

	template<int i_1d_idx, int i_2d_idx>
	inline val_t get_val() const
	{
		static_assert(i_1d_idx < row_num && i_2d_idx < col_num, "ERROR:访问越界!!!!!");
		if (!b_t)
			return pval->get_val<col_num, i_1d_idx, i_2d_idx>();
		else
			return pval->get_val<row_num, i_2d_idx, i_1d_idx>();
	}

	mat<col_num, row_num, val_t> t()
	{
		mat<col_num, row_num, val_t> ret;
		ret.pval = pval;
		ret.b_t = !b_t;
		return ret;
	}

	val_t max_abs() const
	{
		return pval->max_abs();
	}

	void print()
	{
		std::cout << "[" << std::endl;
		for (int i = 0; i < row_num; ++i)
		{
			std::cout << std::setw(3) << "[";
			for (int j = 0; j < col_num; ++j)
			{
				std::cout << (j != 0 ? "," : "") << std::setw(10) << get(i, j);
			}
			std::cout << std::setw(3) << "]" << std::endl;
		}
		std::cout << "]" << std::endl;
	}
};

#endif

base_logic.hpp

#ifndef _BASE_LOGIC_HPP_
#define _BASE_LOGIC_HPP_

/* 矩阵迭代运算 */

template<int r, int c, template<int, int> class op, typename omatt, typename... imatts>
inline void row_loop(omatt& omt, const imatts&...imts)
{
	omt.get_val<r, c>() = op<r, c>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

腾昵猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值