C++ Metaprogramming: Implementing a Unidirectional Deep RNN

Picking up where the previous article left off: wiring an RNN together node by node is a bit of a hassle, so this time I went one step further and built a deep RNN, which is simply single RNN cells stacked on top of each other. Enough talk, straight to the code:

template<int inp_num, int out_num>
struct rnn_node:public cal_chain_node<inp_num, 1, out_num, 1, double>
{
	using inp_type = mat<inp_num, 1, double>;
	using ret_type = mat<out_num, 1, double>;
	inp_type St;	// hidden state carried across time steps
	inp_type dSt;	// recurrent gradient carried across backward passes
	cal_chain_node_mult<HeGaussian, nadam, inp_num, 1, inp_num, double> W,U;
	cal_chain_node_bias<nadam, inp_num, 1, double> b;
	cal_chain_node_act<sigmoid, inp_num, 1, double> f;
	
	cal_chain_node_mult<HeGaussian, nadam, inp_num, 1, out_num, double> V;
	cal_chain_node_act<softmax, out_num, 1, double> g;

	static constexpr int input_num = inp_num;	// a member may not reuse the name of the template parameter

	rnn_node()
	{
	}

	ret_type forward(const inp_type& X)
	{
		// new hidden state: St = f(U*Xt + W*St-1 + b)
		St.template assign<0, 0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
		// output: Ot = g(V*St)
		return g.forward(V.forward(St));
	}

	inp_type backward(const ret_type& delta)
	{
		// the delta reaching f has two sources: the stored recurrent gradient dSt
		// and the output branch through g and V; average the two contributions
		auto delta_before_b = (f.backward(dSt) + f.backward(V.backward(g.backward(delta)))) / 2.;
		auto WU = b.backward(delta_before_b);
		// the gradient flowing back through W becomes the recurrent delta for the next call
		dSt.template assign<0, 0>(W.backward(WU));
		return U.backward(WU);
	}

	void update() 
	{
		W.update();
		U.update();
		b.update();
		f.update();
		V.update();
		g.update();
	}
};
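
Read as formulas, forward() computes the usual single-cell recurrence, written here with the member names above (my own reading of the code): St = f(U·Xt + W·St-1 + b) for the hidden state, followed by Ot = g(V·St) for the output, where f is the sigmoid and g the softmax chosen in the template arguments, and St-1 is whatever St was left at by the previous call. backward() averages the gradient that reaches f from the output branch (through g and V) with the stored recurrent gradient dSt, pushes the result back through b, keeps the W contribution as the next dSt, and returns the U contribution to the caller.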

template<typename cur_t, typename val_t, int cur_node_num, int...nodes_num>
struct rnn_type_cal 
{
	using added_t = typename cur_t::template add_tool<cur_node_num, 1>::type;
	using chain_type = typename rnn_type_cal<added_t, val_t, nodes_num...>::chain_type;
};

template<typename cur_t, typename val_t, int cur_node_num>
struct rnn_type_cal<cur_t, val_t, cur_node_num>
{
	using added_t = typename cur_t::template add_tool<cur_node_num, 1>::type;
	using chain_type = added_t;
};

template<typename val_t, int first_node, int second_node, int...nodes_num>
struct rnn_type_def 
{
	using first_chain = cal_chain<val_t, first_node, 1, second_node, 1>;
	using chain_type = typename rnn_type_cal< first_chain, val_t, nodes_num...>::chain_type;
};

template<int first_node, int second_node, int...nodes_num>
struct gen_rnn 
{
	template<int N, typename chain_type>
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node<first_node, second_node>;
		cur_node_t* cur_node = new cur_node_t;
		chn.template set<N>(cur_node);
		gen_rnn<second_node, nodes_num...>::template gen_rnn_chain<N + 1>(chn);
	}
};

template<int first_node, int second_node>
struct gen_rnn<first_node, second_node>
{
	template<int N, typename chain_type>
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node<first_node, second_node>;
		cur_node_t* cur_node = new cur_node_t;
		chn.template set<N>(cur_node);
	}
};

template<typename val_t, int...node_nums>
typename rnn_type_def<val_t, node_nums...>::chain_type make_rnn() 
{
	using chain_type = typename rnn_type_def<val_t, node_nums...>::chain_type;
	chain_type ret;
	gen_rnn<node_nums...>::template gen_rnn_chain<0>(ret);
	return ret;
}
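
To make the template recursion concrete, here is a minimal usage sketch. The layer sizes (4, 8, 3), the dummy input and the little loop are made up purely for illustration, and the sketch assumes the mat.hpp, activation and update-method headers from the earlier articles are on the include path:

#include <type_traits>

// rnn_type_def<double, 4, 8, 3> unrolls into a two-stage chain:
// the first rnn_node maps 4 -> 8, the second maps 8 -> 3.
static_assert(std::is_same_v<
	rnn_type_def<double, 4, 8, 3>::chain_type,
	cal_chain<double, 4, 1, 8, 1, 3, 1>>);

int main()
{
	auto net = make_rnn<double, 4, 8, 3>();	// two stacked rnn_node layers
	mat<4, 1, double> x;			// one time step of (dummy) input
	for (int t = 0; t < 10; ++t)
	{
		auto y = net.forward(x);	// yields a mat<3, 1, double>
		net.backward(y);		// in real training the delta would come
						// from a loss, e.g. y minus the target
	}
	net.update();				// apply the accumulated gradients
	return 0;
}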

Compared with the previous article, the computation chain cal_chain gains one new piece of machinery, add_tool, which deduces the chain type that results from appending another element (a short compile-time check of add_tool follows the listing). The full chain code is as follows:

#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"

template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>
struct cal_chain_node
{
	using base_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	using ret_type = mat<ret_row, ret_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;

	virtual ret_type forward(const inp_type& inp) = 0;
	virtual inp_type backward(const ret_type& delta) = 0;
	virtual void update() = 0;
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
struct cal_chain 
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node;
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain<val_t, ret_row, ret_col, rest_row_col...>;
	std::shared_ptr<nxt_type> sp_next;
	using ret_type = typename nxt_type::ret_type;

	cal_chain() 
	{
		sp_next = std::make_shared<nxt_type>();
	}

	auto forward(const inp_type& inp)
	{
		return sp_next->forward(sp_cur_node->forward(inp));
	}

	auto backward(const ret_type& ret) 
	{
		return sp_cur_node->backward(sp_next->backward(ret));
	}

	void update() 
	{
		sp_cur_node->update();
		if (sp_next)
			sp_next->update();
	}

	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>& set(set_type* sp)
	{
		if constexpr (remain != 0)
		{
			sp_next->template set<remain - 1, set_type>(sp);
		}
		if constexpr (remain == 0)
		{
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		}
		return *this;
	}

	template<int insert_row, int insert_col>
	struct add_tool 
	{
		using type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col..., insert_row, insert_col>;
	};
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col>
struct cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node;
	using inp_type = typename cur_type::inp_type;
	using ret_type = typename cur_type::ret_type;

	cal_chain() 
	{
	}

	auto forward(const inp_type& inp)
	{
		return sp_cur_node->forward(inp);
	}

	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(ret);
	}

	void update()
	{
		sp_cur_node->update();
	}


	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>& set(set_type* sp)
	{
		static_assert(remain == 0, "set<> index is past the end of the chain");
		if constexpr (remain == 0)
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		return *this;
	}

	template<int insert_row, int insert_col>
	struct add_tool
	{
		using type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, insert_row, insert_col>;
	};
};

template<typename chain>
struct cal_chain_container :public cal_chain_node<chain::inp_type::r, chain::inp_type::c, chain::ret_type::r, chain::ret_type::c, typename chain::type>
{
	using inp_type = typename chain::inp_type;
	using ret_type = typename chain::ret_type;
	chain*		p;
	cal_chain_container(chain* pv):p(pv)
	{}

	ret_type forward(const inp_type& inp)
	{
		return p->forward(inp);
	}

	inp_type backward(const ret_type& ret)
	{
		return p->backward(ret);
	}

	void update()
	{
		p->update();
	}
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
auto make_chain_node(cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>* p)
{
	using chain_type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>;
	cal_chain_container<chain_type> cc(p);
	return cc;
}

#include "weight_initilizer.hpp"
#include "update_methods.hpp"

template<typename ini_t, template<typename> class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult:public cal_chain_node<inp_row, inp_col, ret_row, inp_col, val_t>
{
	using ret_type = mat<ret_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	using weight_type = mat<ret_row, inp_row, val_t>;

	weight_type W;
	weight_type deltaW;
	inp_type pre_inp;

	update_method_templ<weight_type> um;
	double d_num;

	cal_chain_node_mult():d_num(0.)
	{
		weight_initilizer<ini_t>::cal(W);
	}

	virtual ret_type forward(const inp_type& inp) 
	{
		pre_inp = inp;
		return W.dot(inp);
	}

	virtual inp_type backward(const ret_type& delta) 
	{
		auto ret = W.t().dot(delta);
		// keep a running average of the weight gradient over the samples
		// seen since the last update()
		deltaW = deltaW * d_num + delta.dot(pre_inp.t());
		d_num = d_num + 1.;
		if (d_num > 1e-7) 
		{
			deltaW = deltaW / d_num;
		}
		return ret;
	}

	virtual void update() 
	{
		W.template assign<0, 0>(um.update(W, deltaW));
		deltaW = 0.;	// reset the gradient accumulator for the next round
		d_num = 0.;
	}
};

template<template<typename> class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
	using ret_type = mat<inp_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	update_method_templ<inp_type> um;

	inp_type b;
	inp_type deltab;
	double d_num;
	cal_chain_node_bias():d_num(0.)
	{}

	virtual ret_type forward(const inp_type& inp)
	{
		return b + (inp);
	}

	virtual inp_type backward(const ret_type& delta)
	{
		deltab = deltab * d_num + delta;
		d_num = d_num + 1.;
		if (d_num > 1e-7)
		{
			deltab = deltab / d_num;
		}
		return delta;
	}

	virtual void update()
	{
		b.template assign<0, 0>(um.update(b, deltab));
		deltab = 0.;
		d_num = 0.;
	}
};

#include "activate_function.hpp"
template<template<typename> class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
	using ret_type = mat<inp_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	activate_func<inp_type> act_fun;

	virtual ret_type forward(const inp_type& inp)
	{
		return act_fun.forward(inp);
	}

	virtual inp_type backward(const ret_type& delta)
	{
		return act_fun.backward() * delta;
	}

	virtual void update()
	{
	}
};

#endif
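
As mentioned above, add_tool is the one addition to cal_chain: it appends one more (rows, cols) stage to an existing chain type, which is exactly what rnn_type_cal uses to grow the chain one layer at a time. A quick compile-time check (the concrete dimensions are arbitrary):

#include <type_traits>

using two_stage   = cal_chain<double, 4, 1, 8, 1>;		// 4 -> 8
using three_stage = two_stage::add_tool<3, 1>::type;		// 4 -> 8 -> 3

static_assert(std::is_same_v<three_stage, cal_chain<double, 4, 1, 8, 1, 3, 1>>,
	"add_tool<3, 1> appends one more stage to the dimension list");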

Finally, the training runs show that this multi-layer RNN is indeed capable of learning. Since I do not have a sufficiently large sequential dataset with real correlations (and, frankly, was too lazy to put one together), I never ran a serious experiment; interested readers are welcome to try it on their own data.
