C++元编程——四维矩阵简单运算实现

在原来矩阵实现的基础上做了改造,实现了四维矩阵(即元素本身也是矩阵的矩阵)的点积运算,效果拔群,对矩阵元素的运算同样有效。老规矩,先上测试代码:

#include "mat.hpp"

int main(int argc, char** argv)
{
	mat<3, 1, mat<2, 2, double > > m3d{1,2,3};
	m3d.print();
	mat<1, 3, mat<2, 2, double> > m3d2{ {1,2,3,4},{2,3,4,5},{3,4,5,6} };
	m3d2.print();
	auto k = m3d.dot(m3d2);
	k.print();
	return 0;
}

四维矩阵可以看作一个二维矩阵,它的每个元素本身又是一个二维矩阵。点积的运算规则与二维矩阵完全一致,只是每个元素之间的乘法和加法换成了矩阵对应元素的逐元素相乘、逐元素相加,仅此而已。但这恰恰体现了C++元编程的强大之处:如果在运行时实现,你就要重新定义一个类型,然后为这个类型定义加减乘除;而用元编程则不需要——只要模板参数推导合理,同一套点积代码就能直接作用于矩阵元素,前提是你必须把普通标量运算(加、减、乘、除)都实现出对应的矩阵版本。运行结果如下:

结果也是NICE的,正确算出了这两个四维矩阵的点积。下面是修改的代码。

base_function.hpp对原来的点积运算参数推导进行了细化:

#ifndef _BASE_FUNCTION_HPP_
#define _BASE_FUNCTION_HPP_

#include "base_logic.hpp"

/*
 * Numerical derivative of f at v using the central difference
 *     (f(v + step) - f(v - step)) / (2 * step).
 *
 * f     : callable taking one argument; v has the type that f(0) returns.
 * v     : point at which the derivative is estimated.
 * step  : half-width of the difference stencil. The default (1e-11) keeps the
 *         historical behaviour, but it is much smaller than the usual optimum for
 *         double (about 6e-6 * max(1, |v|)). Once |v| reaches about 1.3e5,
 *         v +/- 1e-11 rounds back to v and the estimate collapses to 0, so callers
 *         that need accuracy (or large v) should pass a larger step.
 * returns the central-difference estimate.
 */
template<typename func_t>
auto derivative(func_t&& f, const decltype(f(0))& v, double step = 1e-11)
{
	return (f(v + step) - f(v - step)) / (2. * step);
}

/* Dot product */
/*
 * Inner product of row r1 of mt1 with column c2 of mt2, summed over the inner
 * indices 0..c1 (c1 and r2 are the current, equal, inner index):
 *     mt1(r1, c1) * mt2(r2, c2) + n_dot<r1, c1 - 1, r2 - 1, c2>(...)
 * The recursion is unrolled at compile time and stops at inner index 0.
 */
template<int r1, int c1, int r2, int c2, int i1, int i2, int i3, typename vt = double>
inline vt n_dot(const mat<i1, i2, vt>& mt1, const mat<i2, i3, vt>& mt2)
{
	static_assert(c1 == r2, "[matrix dot error]\tleft matrix column number do not match right matrix's row number.");
	// mt1/mt2 have dependent types, so get_val must be marked as a member
	// template with `template` for standard-conforming compilers (GCC/Clang).
	if constexpr (c1 == 0)
	{
		return mt1.template get_val<r1, c1>() * mt2.template get_val<r2, c2>();
	}
	else
	{
		return mt1.template get_val<r1, c1>() * mt2.template get_val<r2, c2>() + n_dot<r1, c1 - 1, r2 - 1, c2>(mt1, mt2);
	}
}

/*
 * Element functor used by dot(): cal() returns output element (r, c), the inner
 * product of row r of mt1 with column c of mt2.
 */
template<int r, int c>
class v_dot
{
public:
	template<int i1, int i2, int i3, typename vt = double>
	static vt cal(const mat<i1, i2, vt>& mt1, const mat<i2, i3, vt>& mt2)
	{
		// i2 is both mt1's column count and mt2's row count. Using the deduced
		// parameter (instead of mt1.c / mt2.r) keeps the template argument a
		// constant expression: naming a reference parameter there is rejected by
		// GCC/Clang before C++23.
		return n_dot<r, i2 - 1, i2 - 1, c>(mt1, mt2);
	}
};

/*
 * Matrix product of mt1 (row_num1 x col_num1) and mt2 (row_num2 x col_num2).
 * Requires col_num1 == row_num2 and returns a row_num1 x col_num2 matrix whose
 * element (r, c) is computed by v_dot<r, c>::cal. Elements may themselves be
 * matrices, as long as their * and + are defined.
 */
template<int row_num1, int col_num1, int row_num2, int col_num2, typename val_t = double>
mat<row_num1, col_num2, val_t> dot(const mat<row_num1, col_num1, val_t>& mt1, const mat<row_num2, col_num2, val_t>& mt2)
{
	// Report a shape mismatch here, instead of as an opaque deduction failure
	// deep inside col_loop / v_dot.
	static_assert(col_num1 == row_num2, "[matrix dot error]\tleft matrix column number do not match right matrix's row number.");
	mat<row_num1, col_num2, val_t> mt_ret;
	col_loop<col_num2 - 1, v_dot>(mt_ret, mt1, mt2);
	return mt_ret;
}

/* Addition */
/*
 * Element functor: returns mt1(r, c) + mt2(r, c).
 * vt cannot be deduced from the arguments, so it defaults to mt1's element type.
 * A fixed double default would convert float/int results through double and
 * fail outright when the elements are themselves matrices.
 */
template<int r, int c>
class v_add
{
public:
	template<typename imatt1, typename imatt2, typename vt = typename imatt1::type>
	static vt cal(const imatt1& mt1, const imatt2& mt2)
	{
		return mt1.template get_val<r, c>() + mt2.template get_val<r, c>();
	}
};

/* Element functor: returns scalar mt1 plus element (r, c) of the matrix mt2. */
template<int r, int c>
class n_add
{
public:
	template<typename imatt2, typename vt = double>
	static vt cal(const vt& mt1, const imatt2& mt2)
	{
		// `template` is required because mt2 has a dependent type.
		return mt1 + mt2.template get_val<r, c>();
	}
};

/* Element functor for column broadcasting: returns mt(r, c) + v(r, 0). */
template<int r, int c>
struct c_add
{
	template<int row_num, int col_num, typename vt>
	static vt cal(const mat<row_num, col_num, vt>& mt, const mat<row_num, 1, vt>& v)
	{
		// `template` is required because mt and v have dependent types.
		return mt.template get_val<r, c>() + v.template get_val<r, 0>();
	}
};

/* Element functor for row broadcasting: returns mt(r, c) + v(0, c). */
template<int r, int c>
struct r_add
{
	template<int row_num, int col_num, typename vt>
	static vt cal(const mat<row_num, col_num, vt>& mt, const mat<1, col_num, vt>& v)
	{
		// `template` is required because mt and v have dependent types.
		return mt.template get_val<r, c>() + v.template get_val<0, c>();
	}
};

/*
 * Element-wise addition overloads:
 *   mat + mat    : same shape and element type, per element via v_add;
 *   scalar + mat : v + mt(r, c) per element via n_add;
 *   mat + scalar : also v + mt(r, c) per element via n_add (scalar passed first);
 *   a scalar of another type is static_cast to the element type first.
 */
template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator+(const mat<row_num, col_num, val_t>& mt1, const mat<row_num, col_num, val_t>& mt2)
{
	mat<row_num, col_num, val_t> result;
	col_loop<col_num - 1, v_add>(result, mt1, mt2);
	return result;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator+(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> result;
	col_loop<col_num - 1, n_add>(result, v, mt);
	return result;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator+(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	mat<row_num, col_num, val_t> result;
	col_loop<col_num - 1, n_add>(result, v, mt);
	return result;
}

template<int row_num, int col_num, typename val_t_other, typename val_t = double>
mat<row_num, col_num, val_t> operator+(const mat<row_num, col_num, val_t>& mt, const val_t_other& v)
{
	const val_t as_elem = static_cast<val_t>(v);
	return mt + as_elem;
}

template<int row_num, int col_num, typename val_t_other, typename val_t = double>
mat<row_num, col_num, val_t> operator+(const val_t_other& v, const mat<row_num, col_num, val_t>& mt)
{
	const val_t as_elem = static_cast<val_t>(v);
	return mt + as_elem;
}

/*
 * spread_add: broadcasting addition written into mt_ret.
 *   matrix + column vector (row_num x 1): mt(r, c) + v(r, 0)   via c_add
 *   matrix + row vector    (1 x col_num): mt(r, c) + v(0, c)   via r_add
 *   matrix + 1 x 1 matrix               : v(0, 0) + mt(r, c)   via n_add
 * Each form is also accepted with the operands swapped. The final 1x1 + 1x1
 * overload is more specialized than both 1x1 forms, which would otherwise be
 * ambiguous for that shape.
 */
template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<row_num, col_num, val_t>& mt, const mat<row_num, 1, val_t>& v)
{
	col_loop<col_num - 1, c_add>(mt_ret, mt, v);
}

template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<row_num, col_num, val_t>& mt, const mat<1, col_num, val_t>& v)
{
	col_loop<col_num - 1, r_add>(mt_ret, mt, v);
}

template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<row_num, col_num, val_t>& mt, const mat<1, 1, val_t>& v)
{
	// `template` is required because v has a dependent type.
	col_loop<col_num - 1, n_add>(mt_ret, v.template get_val<0, 0>(), mt);
}

template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<row_num, 1, val_t>& v, const mat<row_num, col_num, val_t>& mt)
{
	col_loop<col_num - 1, c_add>(mt_ret, mt, v);
}

template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<1, col_num, val_t>& v, const mat<row_num, col_num, val_t>& mt)
{
	col_loop<col_num - 1, r_add>(mt_ret, mt, v);
}

template<int row_num, int col_num, typename val_t = double>
void spread_add(mat<row_num, col_num, val_t>& mt_ret, const mat<1, 1, val_t>& v, const mat<row_num, col_num, val_t>& mt)
{
	col_loop<col_num - 1, n_add>(mt_ret, v.template get_val<0, 0>(), mt);
}

template<typename val_t = double>
void spread_add(mat<1, 1, val_t>& mt_ret, const mat<1, 1, val_t>& v, const mat<1, 1, val_t>& mt)
{
	// n_add::cal(scalar, matrix) reads element (r, c) of the matrix itself, so
	// pass mt. mt.get_val<0, 0> without parentheses named a member function.
	col_loop<0, n_add>(mt_ret, v.template get_val<0, 0>(), mt);
}

/* Subtraction */
/*
 * Element functor: returns mt1(r, c) - mt2(r, c).
 * vt cannot be deduced from the arguments, so it defaults to mt1's element type
 * rather than a fixed double (which breaks non-double and matrix elements).
 */
template<int r, int c>
class v_minus
{
public:
	template<typename imatt1, typename imatt2, typename vt = typename imatt1::type>
	static vt cal(const imatt1& mt1, const imatt2& mt2)
	{
		return mt1.template get_val<r, c>() - mt2.template get_val<r, c>();
	}
};

/*
 * Element functor for scalar/matrix subtraction:
 *   cal(v, mt)  -> v - mt(r, c)
 *   cal(mt, v)  -> mt(r, c) - v
 */
template<int r, int c>
class n_minus
{
public:
	template<int mat_r, int mat_c, typename vt = double>
	static vt cal(const vt& v, const mat<mat_r, mat_c, vt>& mt2)
	{
		// `template` is required because mt2 has a dependent type.
		return v - mt2.template get_val<r, c>();
	}

	template<int mat_r, int mat_c, typename vt = double>
	static vt cal(const mat<mat_r, mat_c, vt>& mt2, const vt& v)
	{
		return mt2.template get_val<r, c>() - v;
	}
};

/*
 * Element-wise subtraction overloads:
 *   mat - mat    : same shape and element type, per element via v_minus;
 *   scalar - mat : v - mt(r, c) per element via n_minus;
 *   mat - scalar : mt(r, c) - v per element via n_minus;
 *   a scalar of another type is static_cast to the element type first.
 */
template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator-(const mat<row_num, col_num, val_t>& mt1, const mat<row_num, col_num, val_t>& mt2)
{
	mat<row_num, col_num, val_t> diff;
	col_loop<col_num - 1, v_minus>(diff, mt1, mt2);
	return diff;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator-(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> diff;
	col_loop<col_num - 1, n_minus>(diff, v, mt);
	return diff;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator-(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	mat<row_num, col_num, val_t> diff;
	col_loop<col_num - 1, n_minus>(diff, mt, v);
	return diff;
}

template<int row_num, int col_num, typename val_t_other, typename val_t = double>
mat<row_num, col_num, val_t> operator-(const mat<row_num, col_num, val_t>& mt, const val_t_other& v)
{
	const val_t as_elem = static_cast<val_t>(v);
	return mt - as_elem;
}

template<int row_num, int col_num, typename val_t_other, typename val_t = double>
mat<row_num, col_num, val_t> operator-(const val_t_other& v, const mat<row_num, col_num, val_t>& mt)
{
	const val_t as_elem = static_cast<val_t>(v);
	return as_elem - mt;
}

/* Multiplication */
/* Element functor: returns scalar mt1 times element (r, c) of the matrix mt2. */
template<int r, int c>
class n_mul
{
public:
	template<typename imatt2, typename vt = double>
	static vt cal(const vt& mt1, const imatt2& mt2)
	{
		// `template` is required because mt2 has a dependent type.
		return mt1 * mt2.template get_val<r, c>();
	}
};

/*
 * Element functor for the element-wise (Hadamard) product: mt1(r, c) * mt2(r, c).
 * vt cannot be deduced from the arguments, so it defaults to mt1's element type
 * rather than a fixed double (which breaks non-double and matrix elements).
 */
template<int r, int c>
class v_mul
{
public:
	template<typename imatt1, typename imatt2, typename vt = typename imatt1::type>
	static vt cal(const imatt1& mt1, const imatt2& mt2)
	{
		return mt1.template get_val<r, c>() * mt2.template get_val<r, c>();
	}
};

/*
 * Multiplication overloads.
 *   scalar * mat and mat * scalar: v * mt(r, c) per element via n_mul
 *   (the scalar is always passed first).
 *   mat * mat: ELEMENT-WISE product via v_mul, not the matrix product (use dot()).
 */
template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator*(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> scaled;
	col_loop<col_num - 1, n_mul>(scaled, v, mt);
	return scaled;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator*(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	mat<row_num, col_num, val_t> scaled;
	col_loop<col_num - 1, n_mul>(scaled, v, mt);
	return scaled;
}

template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> operator*(const mat<row_num, col_num, val_t>& mt1, const mat<row_num, col_num, val_t>& mt2)
{
	mat<row_num, col_num, val_t> hadamard;
	col_loop<col_num - 1, v_mul>(hadamard, mt1, mt2);
	return hadamard;
}

/* Division */
/*
 * Element functor for scalar/matrix division:
 *   cal(mt, v) -> mt(r, c) / v
 *   cal(v, mt) -> v / mt(r, c)
 */
template<int r, int c>
class n_div
{
public:
	template<int row_num, int col_num, typename vt = double>
	static vt cal(const mat<row_num, col_num, vt>& mt, const vt& v)
	{
		// `template` is required because mt has a dependent type.
		return mt.template get_val<r, c>() / v;
	}

	template<int row_num, int col_num, typename vt = double>
	static vt cal(const vt& v, const mat<row_num, col_num, vt>& mt)
	{
		return v / mt.template get_val<r, c>();
	}
};

/* Element functor for element-wise division: mt1(r, c) / mt2(r, c). */
template<int r, int c>
class v_div
{
public:
	template<int row_num, int col_num, typename vt = double>
	static vt cal(const mat<row_num, col_num, vt>& mt1, const mat<row_num, col_num, vt>& mt2)
	{
		// `template` is required because mt1/mt2 have dependent types.
		return mt1.template get_val<r, c>() / mt2.template get_val<r, c>();
	}
};

/*
 * Division overloads:
 *   mat / scalar : mt(r, c) / v per element via n_div;
 *   scalar / mat : v / mt(r, c) per element via n_div;
 *   mat / mat    : mt1(r, c) / mt2(r, c), element-wise via v_div.
 */
template<int row_num, int col_num, typename val_t>
mat<row_num, col_num, val_t> operator/(const mat<row_num, col_num, val_t>& mt, const val_t& v)
{
	mat<row_num, col_num, val_t> quotient;
	col_loop<col_num - 1, n_div>(quotient, mt, v);
	return quotient;
}

template<int row_num, int col_num, typename val_t>
mat<row_num, col_num, val_t> operator/(const val_t& v, const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> quotient;
	col_loop<col_num - 1, n_div>(quotient, v, mt);
	return quotient;
}

template<int row_num, int col_num, typename vt>
mat<row_num, col_num, vt> operator/(const mat<row_num, col_num, vt>& mt1, const mat<row_num, col_num, vt>& mt2)
{
	mat<row_num, col_num, vt> quotient;
	col_loop<col_num - 1, v_div>(quotient, mt1, mt2);
	return quotient;
}

/*
 * Element functor: square root of mt(r, c), computed with sqrtl (long double)
 * and converted back to the matrix element type. The result type follows the
 * element type, like n_exp, instead of a fixed double.
 */
template<int r, int c>
class n_sqrt
{
public:
	template<typename imatt, typename vt = typename imatt::type>
	static vt cal(const imatt& mt)
	{
		// `template` is required because mt has a dependent type.
		return static_cast<vt>(sqrtl(mt.template get_val<r, c>()));
	}
};

/* Element-wise square root of mt; each element is computed by n_sqrt. */
template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> sqrtm(const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> roots;
	constexpr int last_col = col_num - 1;
	col_loop<last_col, n_sqrt>(roots, mt);
	return roots;
}

/* exp */
/* Element functor: exp(mt(r, c)), returned as the matrix element type. */
template<int r, int c>
struct n_exp
{
	template<typename imatt>
	static typename imatt::type cal(const imatt& mt)
	{
		// `template` is required because mt has a dependent type.
		return exp(mt.template get_val<r, c>());
	}
};

/*
 * Element-wise exponential of mt (each element via n_exp). Note that this is
 * not the matrix exponential.
 */
template<int row_num, int col_num, typename val_t = double>
mat<row_num, col_num, val_t> expm(const mat<row_num, col_num, val_t>& mt)
{
	mat<row_num, col_num, val_t> exps;
	constexpr int last_col = col_num - 1;
	col_loop<last_col, n_exp>(exps, mt);
	return exps;
}

/* Convolution */
/*
 * Contribution of template row row_delta at anchor (row_base, col_base):
 *   sum over dc = 0..col_delta of
 *     mt_origin(row_base + row_delta, col_base + dc) * mt_tpl(row_delta, dc)
 * Unrolled at compile time from dc = col_delta down to 0.
 */
template<int row_base, int col_base, int row_delta, int col_delta, typename imat_origin, typename imat_tpl>
inline auto col_loop_mul(const imat_origin& mt_origin, const imat_tpl& mt_tpl)
{
	// `template` is required: both matrices have dependent types.
	if constexpr (col_delta != 0)
	{
		return mt_origin.template get_val<row_base + row_delta, col_base + col_delta>() * mt_tpl.template get_val<row_delta, col_delta>()
			+ col_loop_mul<row_base, col_base, row_delta, col_delta - 1, imat_origin, imat_tpl>(mt_origin, mt_tpl);
	}
	else
	{
		return mt_origin.template get_val<row_base + row_delta, col_base + col_delta>() * mt_tpl.template get_val<row_delta, col_delta>();
	}
}

/*
 * Full template response at anchor (row_base, col_base): the sum over template
 * rows dr = 0..row_delta of col_loop_mul<row_base, col_base, dr, col_delta>.
 * Unrolled at compile time from dr = row_delta down to 0.
 */
template<int row_base, int col_base, int row_delta, int col_delta, typename imat_origin, typename imat_tpl>
inline auto row_loop_add(const imat_origin& mt_origin, const imat_tpl& mt_tpl)
{
	if constexpr (row_delta != 0)
	{
		// Recurse into row_loop_add (not col_loop_mul) so every remaining row is
		// summed. Calling col_loop_mul for row_delta - 1 only would drop rows
		// 0..row_delta-2 whenever the template has three or more rows.
		return col_loop_mul<row_base, col_base, row_delta, col_delta>(mt_origin, mt_tpl)
			+ row_loop_add<row_base, col_base, row_delta - 1, col_delta>(mt_origin, mt_tpl);
	}
	else
	{
		return col_loop_mul<row_base, col_base, row_delta, col_delta>(mt_origin, mt_tpl);
	}
}

/*
 * Element functor for a stride-1 "valid" convolution (cross-correlation form:
 * the template is not flipped). Output (r, c) is the sum of mt_tpl multiplied
 * element by element with the window of mt_origin whose top-left is (r, c).
 */
template<int r, int c>
struct v_inner_conv
{
	template<typename imat_origin_t, typename imat_tpl_t, typename val_t = double>
	inline static auto cal(const imat_origin_t& mt_origin, const imat_tpl_t& mt_tpl)
	{
		constexpr int last_tpl_row = imat_tpl_t::r - 1;
		constexpr int last_tpl_col = imat_tpl_t::c - 1;
		return row_loop_add<r, c, last_tpl_row, last_tpl_col>(mt_origin, mt_tpl);
	}
};

/* Number of template positions along one axis for a valid strided pass: (origin - tpl) / step + 1. */
constexpr int get_step_inner_size(int i_origin, int i_tpl, int i_step)
{
	const int span = i_origin - i_tpl;
	return span / i_step + 1;
}

/*
 * Padding needed along one axis so that (origin + pad - tpl) is a multiple of
 * step. For origin >= tpl this is ceil((origin - tpl) / step) * step - (origin - tpl).
 */
constexpr int get_pad_size(int i_origin, int i_tpl, int i_step)
{
	const int span = i_origin - i_tpl;
	const int strides = span / i_step + (span % i_step != 0 ? 1 : 0);
	return strides * i_step - span;
}

/* Integer division rounded up (for non-negative operands). */
constexpr int get_ceil_div(int i_origin, int i_tpl)
{
	const int whole = i_origin / i_tpl;
	return i_origin % i_tpl != 0 ? whole + 1 : whole;
}

/*
 * Per-side padding for a strided pass of a tpl_row x tpl_col template over an
 * input_row x input_col input. The total padding per axis is get_pad_size(...);
 * top/left take half of it (integer division) and bottom/right take the rest.
 */
template<int input_row, int input_col, int tpl_row, int tpl_col, int row_step, int col_step>
struct pad_size_t
{
	static constexpr int top = get_pad_size(input_row, tpl_row, row_step) / 2;
	static constexpr int bottom = get_pad_size(input_row, tpl_row, row_step) - top;
	static constexpr int left = get_pad_size(input_col, tpl_col, col_step) / 2;
	static constexpr int right = get_pad_size(input_col, tpl_col, col_step) - left;
};

template<int row_step, int col_step, int row_num, int col_num, int tpl_row, int tpl_col, typename val_t>
inline mat<get_step_inner_size(row_num, tpl_row, row_step), get_step_inner_size(col_num, tpl_col, col_step), val_t>
inner_conv(const mat<row_num, col_num, val_t>& mt_origin, const mat<tpl_row, tpl_col, val_t>& mt_tpl)
{
	using ret_type = mat<get_step_inner_size(row_num, tpl_row, row_step), get_step_inner_size(col_num, tpl_col, col_step), val_t>;
	ret_type mt_ret;
	col_loop<ret_type::c - 1, v_inner_conv>(mt_ret, mt_origin, mt_tpl);
	return mt_ret;
}

/* all_size: total element count (sum of r * c) over every listed matrix type. */
template<typename mat_t, typename ...mat_ts>
struct st_one_col
{
	static constexpr int all_size = (mat_t::r * mat_t::c) + (0 + ... + (mat_ts::r * mat_ts::c));
};

/*
 * Copies the raw storage (pval->p) of mt, then of each of mts, back to back
 * into p. This is storage order, so the transpose flag b_t is not honoured.
 * p must have room for the combined element count (see st_one_col). Every
 * matrix is copied with memcpy, so element types must match and be trivially
 * copyable.
 */
template<typename mat_t, typename ...mat_ts>
void concat_mat(typename mat_t::type* p, const mat_t& mt, const mat_ts&... mts)
{
	constexpr int cpy_size = mat_t::r * mat_t::c;
	// `typename` is required: mat_t::type is a dependent type, and without it
	// sizeof() parses it as an expression.
	memcpy(p, mt.pval->p, cpy_size * sizeof(typename mat_t::type));
	if constexpr (0 != sizeof...(mat_ts))
		concat_mat(p + cpy_size, mts...);
}

/*
 * Concatenates the storage of all the given matrices into a single column
 * matrix (see concat_mat). The result uses the first matrix's element type.
 * A hard-coded double result could not receive float/int data: concat_mat
 * expects a pointer to the first matrix's element type.
 */
template<typename mat_t, typename ...mat_ts>
mat<st_one_col<mat_t, mat_ts...>::all_size, 1, typename mat_t::type> stretch_one_col(const mat_t& mt, const mat_ts&...mts)
{
	using ret_type = mat<st_one_col<mat_t, mat_ts...>::all_size, 1, typename mat_t::type>;
	ret_type ret;
	concat_mat(ret.pval->p, mt, mts...);
	return ret;
}


/*
 * Inverse of concat_mat: fills the raw storage of mt, then of each of mts, from
 * consecutive chunks of p. It writes through mt.pval, which is a shared_ptr to
 * the caller's storage, so the const reference still lets the data reach the
 * caller's matrices.
 */
template<typename mat_t, typename ...mat_ts>
void split_mat(typename mat_t::type* p, const mat_t& mt, const mat_ts&... mts)
{
	constexpr int cpy_size = mat_t::r * mat_t::c;
	// `typename` is required for the dependent type inside sizeof().
	memcpy(mt.pval->p, p, cpy_size * sizeof(typename mat_t::type));
	if constexpr (0 != sizeof...(mat_ts))
		split_mat(p + cpy_size, mts...);
}

/*
 * Distributes the raw storage of mt across the matrices mts, in order
 * (presumably mt is a column built by stretch_one_col; verify at call sites).
 */
template<typename mat_t, typename ...mat_ts>
void split_one_mat(const mat_t& mt, const mat_ts&...mts)
{
	auto* const src = mt.pval->p;
	split_mat(src, mts...);
}

#endif

矩阵类型定义在 mat.hpp 中,这次新增了输出运算符 operator<<,用来打印结果:

#ifndef _MAT_HPP_
#define _MAT_HPP_
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstdio>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>

#ifdef USE_BOOST
#include <boost/pool/pool.hpp>
#endif

/*
 * Flat element storage used by mat: i_size elements of val_t, addressed as a
 * 2-D array whose row length len_1d is supplied per access. Element
 * (i_1d_idx, i_2d_idx) lives at p[i_2d_idx + len_1d * i_1d_idx].
 * With USE_BOOST the buffer comes from a boost::pool shared by all instances of
 * the same <i_size, val_t>; otherwise it comes from new[].
 */
template<int i_size, typename val_t>
struct mat_m
{
#ifdef USE_BOOST
	static boost::pool<> s_pool;
#endif
	val_t* p;

	/* Allocates i_size elements, all set to zero on both allocation paths. */
	mat_m() :p(nullptr)
	{
#ifdef USE_BOOST
		p = static_cast<val_t*>(s_pool.malloc());
		for (int i = 0; i < i_size; ++i)
		{
			p[i] = 0;
		}
#else
		// The trailing () value-initialises the array (0 for arithmetic types),
		// matching the USE_BOOST path. mat::pad and mat::span write only part of a
		// freshly created matrix, so the rest must not be left indeterminate.
		p = new val_t[i_size]();
#endif
	}

	// mat_m owns p, so a member-wise copy would release the same buffer twice.
	mat_m(const mat_m&) = delete;
	mat_m& operator=(const mat_m&) = delete;

	~mat_m()
	{
		if (p)
		{
#ifdef USE_BOOST
			s_pool.free(p);
#else
			delete[] p;
#endif
		}
	}

	/* Runtime access to element (i_1d_idx, i_2d_idx) for row length len_1d. Not bounds-checked. */
	val_t& get(const int& len_1d, const int& i_1d_idx, const int& i_2d_idx)
	{
		val_t& ret = p[i_2d_idx + len_1d * i_1d_idx];
		return ret;
	}

	/* Largest absolute element value (accumulated in double). */
	val_t max_abs() const
	{
		double d = -1 * DBL_MAX;
		for (int i = 0; i < i_size; ++i)
		{
			// std::abs selects the floating-point overload. An unqualified abs()
			// may resolve to C's abs(int) and truncate fractional values.
			const auto mag = std::abs(p[i]);
			d = d < mag ? mag : d;
		}
		return d;
	}

	/* Largest element value (accumulated in double). */
	val_t max() const
	{
		double d = -1 * DBL_MAX;
		for (int i = 0; i < i_size; ++i)
		{
			d = d < (p[i]) ? (p[i]) : d;
		}
		return d;
	}

	/* Sum of all elements (accumulated in double). */
	val_t sum() const
	{
		double d_sum = 0.;
		for (int i = 0; i < i_size; ++i)
		{
			d_sum += p[i];
		}
		return d_sum;
	}

	/* Compile-time-indexed access; the flattened index is checked against i_size. */
	template<int len_1d, int i_1d_idx, int i_2d_idx>
	inline val_t& get_val()
	{
		static_assert((i_2d_idx + len_1d * i_1d_idx) < i_size, "ERROR:mat_m over flow!!!");
		return p[i_2d_idx + len_1d * i_1d_idx];
	}

	template<int len_1d, int i_1d_idx, int i_2d_idx>
	inline val_t get_val() const
	{
		static_assert((i_2d_idx + len_1d * i_1d_idx) < i_size, "ERROR:mat_m over flow!!!");
		return p[i_2d_idx + len_1d * i_1d_idx];
	}
};

#ifdef USE_BOOST
// One pool per <i_size, val_t>, handing out blocks sized for exactly one matrix.
template<int i_size, typename val_t>
boost::pool<> mat_m<i_size, val_t>::s_pool = boost::pool<>(i_size * sizeof(val_t));
#endif

/*
 * Fixed-size row_num x col_num matrix with element type val_t. val_t may itself
 * be a mat, which gives block ("four-dimensional") matrices. Elements live in a
 * shared mat_m buffer: copies are shallow and share storage, and t() returns a
 * transposed view of the same buffer (b_t == true).
 */
template<int row_num, int col_num, typename val_t = double>
struct mat
{
	/*
	 * Writes mt to os as "[[a,b,...][c,d,...]...]" (rows in order, no newlines)
	 * and returns os so insertions can be chained.
	 */
	friend std::ostream& operator<<(std::ostream& os, const mat<row_num, col_num, val_t>& mt)
	{
		// Write to the stream passed in. Writing to std::cout unconditionally made
		// e.g. `oss << m` print to the console and return the wrong stream.
		os << "[";
		for (int i = 0; i < row_num; ++i)
		{
			os << "[";
			for (int j = 0; j < col_num; ++j)
			{
				os << (j != 0 ? "," : "") << mt.get(i, j);
			}
			os << "]";
		}
		os << "]";
		return os;
	}

	using type = val_t;
	typedef val_t vt;
	static constexpr int r = row_num;
	static constexpr int c = col_num;
	using mat_m_t = mat_m<row_num* col_num, val_t>;
	// Row-major element storage, shared between copies and views.
	std::shared_ptr<mat_m_t> pval;
	// true when this object views pval transposed (see t()).
	bool b_t;

	/* Allocates fresh storage; element values are whatever mat_m's constructor leaves. */
	mat() :b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
	}

	/* Shallow copy: shares other's storage and transpose flag. */
	mat(const mat<row_num, col_num, val_t>& other) :pval(other.pval), b_t(other.b_t)
	{
	}

	/*
	 * Fills every element with v. Deliberately implicit so a scalar converts to a
	 * constant block, e.g. mat<3, 1, mat<2, 2>> m{1, 2, 3}. Taking const& (rather
	 * than const&&) also accepts lvalue arguments.
	 */
	mat(const val_t& v) :b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				pval->get(col_num, i, j) = v;
			}
		}
	}

	/*
	 * Fills elements row by row from lst. If lst is short, the remaining
	 * elements are left as allocated. Extra values are ignored.
	 */
	mat(const std::initializer_list<val_t>& lst) :b_t(false)
	{
		pval = std::make_shared<mat_m_t>();
		auto itr = lst.begin();
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				if (itr == lst.end())return;
				pval->get(col_num, i, j) = *itr;
				itr++;
			}
		}
	}

	/* Runtime element access (respects the transpose flag). Not bounds-checked. */
	val_t& get(const int& i_row, const int& i_col)
	{
		if (!b_t)
			return pval->get(col_num, i_row, i_col);
		else
			return pval->get(row_num, i_col, i_row);
	}

	val_t get(const int& i_row, const int& i_col) const
	{
		if (!b_t)
			return pval->get(col_num, i_row, i_col);
		else
			return pval->get(row_num, i_col, i_row);
	}

	/*
	 * Compile-time-indexed element access (respects the transpose flag).
	 * pval's type depends on the class template parameters, so the member-template
	 * calls need the `template` disambiguator on GCC/Clang.
	 */
	template<int i_1d_idx, int i_2d_idx>
	inline val_t& get_val()
	{
		static_assert(0 <= i_1d_idx && i_1d_idx < row_num && 0 <= i_2d_idx && i_2d_idx < col_num, "ERROR: mat::get_val overflow!!!!!");
		if (!b_t)
			return pval->template get_val<col_num, i_1d_idx, i_2d_idx>();
		else
			return pval->template get_val<row_num, i_2d_idx, i_1d_idx>();
	}

	template<int i_1d_idx, int i_2d_idx>
	inline val_t get_val() const
	{
		static_assert(0 <= i_1d_idx && i_1d_idx < row_num && 0 <= i_2d_idx && i_2d_idx < col_num, "ERROR: mat::get_val overflow!!!!!");
		if (!b_t)
			return pval->template get_val<col_num, i_1d_idx, i_2d_idx>();
		else
			return pval->template get_val<row_num, i_2d_idx, i_1d_idx>();
	}

	/* Transposed view sharing this matrix's storage (no element copy). */
	mat<col_num, row_num, val_t> t() const
	{
		mat<col_num, row_num, val_t> ret;
		ret.pval = pval;
		ret.b_t = !b_t;
		return ret;
	}

	val_t max_abs() const
	{
		return pval->max_abs();
	}

	val_t max() const
	{
		return pval->max();
	}

	val_t sum() const
	{
		return pval->sum();
	}

	/* Prints the matrix to std::cout, one row per line. */
	void print() const
	{
		std::cout << "[" << std::endl;
		for (int i = 0; i < row_num; ++i)
		{
			std::cout << std::setw(3) << "[";
			for (int j = 0; j < col_num; ++j)
			{
				std::cout << (j != 0 ? "," : "") << std::setw(10) << get(i, j);
			}
			std::cout << std::setw(3) << "]" << std::endl;
		}
		std::cout << "]" << std::endl;
	}

	/* Matrix product with mt (forwards to the free function ::dot). */
	template<int other_col_num>
	mat<row_num, other_col_num, val_t> dot(const mat<col_num, other_col_num, val_t>& mt) const
	{
		return ::dot(*this, mt);
	}

	/* New matrix rotated by 180 degrees: element (i, j) = this(row_num-1-i, col_num-1-j). */
	mat<row_num, col_num, val_t> rot180() const
	{
		mat<row_num, col_num, val_t> ret;
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				ret.get(i, j) = get(row_num - 1 - i, col_num - 1 - j);
			}
		}
		return ret;
	}

	/*
	 * Copies mt_other into this matrix with its (0, 0) placed at
	 * (row_base, col_base). Parts falling outside this matrix are skipped.
	 * Implemented as a plain runtime loop on purpose.
	 */
	template<int row_base, int col_base, int row_num_other, int col_num_other>
	void assign(const mat<row_num_other, col_num_other, val_t>& mt_other)
	{
		for (int i = 0; i < row_num_other; ++i)
		{
			for (int j = 0; j < col_num_other; ++j)
			{
				if (i + row_base < 0 || j + col_base < 0)
				{
					continue;
				}
				if (i + row_base >= row_num || j + col_base >= col_num)
				{
					break;
				}
				get(i + row_base, j + col_base) = mt_other.get(i, j);
			}
		}
	}

	/* New matrix with this one placed inside the given margins; margins keep freshly allocated values. */
	template<int top_pad, int left_pad, int right_pad, int bottom_pad>
	mat<row_num + top_pad + bottom_pad, col_num + left_pad + right_pad, val_t>
		pad() const
	{
		using mat_ret_t = mat<row_num + top_pad + bottom_pad, col_num + left_pad + right_pad, val_t>;
		mat_ret_t mt_ret;
		// mt_ret has a dependent type, so the member-template call needs `template`.
		mt_ret.template assign<top_pad, left_pad>(*this);
		return mt_ret;
	}

	/*
	 * New matrix with the elements spread apart: element (i, j) moves to
	 * (i * (row_span + 1), j * (col_span + 1)). The gaps keep freshly allocated
	 * values. The element type is preserved (it used to fall back to double).
	 */
	template<int row_span, int col_span>
	mat<row_num + row_span * (row_num - 1), col_num + col_span * (col_num - 1), val_t>
		span() const
	{
		using mat_ret_t = mat<row_num + row_span * (row_num - 1), col_num + col_span * (col_num - 1), val_t>;
		mat_ret_t mt_ret;
		for (int i = 0; i < row_num; ++i)
		{
			for (int j = 0; j < col_num; ++j)
			{
				mt_ret.get(i * (row_span + 1), j * (col_span + 1)) = get(i, j);
			}
		}
		return mt_ret;
	}

	/*
	 * Maximum over the window starting at (row_base, col_base) of size
	 * row_len x col_len, clipped to the matrix. Stores the position of the first
	 * maximum found (row-major scan) in i_row/i_col. They are left unchanged if
	 * no element exceeds -DBL_MAX.
	 */
	template<int row_base, int col_base, int row_len, int col_len>
	val_t region_max(int& i_row, int& i_col) const
	{
		static_assert(row_base < row_num&& col_base < col_num, "region_max overflow!!!");
		val_t d_max = -1. * DBL_MAX;
		for (int i = row_base; i < row_base + row_len && i < row_num; ++i)
		{
			for (int j = col_base; j < col_base + col_len && j < col_num; ++j)
			{
				if (d_max < get(i, j))
				{
					i_row = i, i_col = j;
					d_max = get(i, j);
				}
			}
		}
		return d_max;
	}

	/*
	 * (row_num * col_num) x 1 view sharing this matrix's storage, in storage
	 * order. The transpose flag is not carried over, so for a t() view the
	 * column follows the underlying row-major buffer, not the transposed order.
	 */
	mat<row_num* col_num, 1, val_t> one_col() const
	{
		mat<row_num* col_num, 1, val_t> ret;
		ret.pval = pval;
		return ret;
	}

	static void print_type()
	{
		printf("<matrix %d * %d>\r\n", row_num, col_num);
	}
};


#endif
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

腾昵猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值