C++元编程——计算链和RNN

反向传播时候有个计算链,误差传播时也是反向走过各个计算链,所以这个计算链的概念很重要。那么层间单向的RNN计算链可以表现为下图:

 大写字母W、U、V表示点积运算,B是偏移运算,f和g是激活运算,+是相加运算。反向传播就可以从链上看出来结果,前一个误差经过当前运算得到当前误差,同时得到当前运算参数的偏导数,并更新当前运算参数,如此往复向前进行更新。如果遇到分支,可以想像成两次输出的误差,可以求均值也可以分两次训练。我采用的是求均值然后再更新。

下面是这个计算链的实现cal_chain.hpp:

#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"

template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>
struct cal_chain_node
{
	using base_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	using ret_type = mat<ret_row, ret_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;

	virtual ret_type forward(const inp_type& inp) = 0;
	virtual inp_type backward(const ret_type& delta) = 0;
	virtual void update() = 0;
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
struct cal_chain 
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node;
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain<val_t, ret_row, ret_col, rest_row_col...>;
	std::shared_ptr<nxt_type> sp_next;
	using ret_type = typename nxt_type::ret_type;

	cal_chain() 
	{
		sp_next = std::make_shared<nxt_type>();
	}

	auto forward(const inp_type& inp)
	{
		return sp_next->forward(sp_cur_node->forward(inp));
	}

	auto backward(const ret_type& ret) 
	{
		return sp_cur_node->backward(sp_next->backward(ret));
	}

	void update() 
	{
		sp_cur_node->update();
		if (sp_next)
			sp_next->update();
	}

	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>& set(set_type* sp)
	{
		if constexpr (remain != 0)
		{
			sp_next->set<remain - 1, set_type>(sp);
		}
		if constexpr (remain == 0)
		{
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		}
		return *this;
	}
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col>
struct cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node;
	using inp_type = typename cur_type::inp_type;
	using ret_type = typename cur_type::ret_type;

	cal_chain() 
	{
	}

	auto forward(const inp_type& inp)
	{
		return sp_cur_node->forward(inp);
	}

	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(ret);
	}

	void update()
	{
		sp_cur_node->update();
	}


	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>& set(set_type* sp)
	{
		static_assert(remain == 0, "over over");
		if constexpr (remain == 0)
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		return *this;
	}
};

template<typename chain>
struct cal_chain_container :public cal_chain_node<chain::inp_type::r, chain::inp_type::c, chain::ret_type::r, chain::ret_type::c, typename chain::type>
{
	using inp_type = typename chain::inp_type;
	using ret_type = typename chain::ret_type;
	chain*		p;
	cal_chain_container(chain* pv):p(pv)
	{}

	ret_type forward(const inp_type& inp)
	{
		return p->forward(inp);
	}

	inp_type backward(const ret_type& ret)
	{
		return p->backward(ret);
	}

	void update()
	{
		p->update();
	}
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
auto make_chain_node(cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>* p)
{
	using chain_type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>;
	cal_chain_container<chain_type> cc(p);
	return cc;
}

#include "weight_initilizer.hpp"
#include "update_methods.hpp"

template<typename ini_t, template<typename> class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult:public cal_chain_node<inp_row, inp_col, ret_row, inp_col, val_t>
{
	using ret_type = mat<ret_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	using weight_type = mat<ret_row, inp_row, val_t>;

	weight_type W;
	weight_type deltaW;
	inp_type pre_inp;

	update_method_templ<weight_type> um;
	double d_num;

	cal_chain_node_mult():d_num(0.)
	{
		weight_initilizer<ini_t>::cal(W);
	}

	virtual ret_type forward(const inp_type& inp) 
	{
		pre_inp = inp;
		return W.dot(inp);
	}

	virtual inp_type backward(const ret_type& delta) 
	{
		auto ret = W.t().dot(delta);
		deltaW = deltaW * d_num + delta.dot(pre_inp.t());
		d_num = d_num + 1.;
		if (d_num > 1e-7) 
		{
			deltaW = deltaW / d_num;
		}
		return ret;
	}

	virtual void update() 
	{
		W.assign<0, 0>(um.update(W, deltaW));
		deltaW = 0.;
		d_num = 0.;
	}
};

template<template<typename> class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
	using ret_type = mat<inp_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	update_method_templ<inp_type> um;

	inp_type b;
	inp_type deltab;
	double d_num;
	cal_chain_node_bias():d_num(0.)
	{}

	virtual ret_type forward(const inp_type& inp)
	{
		return b + (inp);
	}

	virtual inp_type backward(const ret_type& delta)
	{
		deltab = deltab * d_num + delta;
		d_num = d_num + 1.;
		if (d_num > 1e-7)
		{
			deltab = deltab / d_num;
		}
		return delta;
	}

	virtual void update()
	{
		b.assign<0, 0>(um.update(b, deltab));
		deltab = 0.;
		d_num = 0.;
	}
};

#include "activate_function.hpp"
template<template<typename> class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
	using ret_type = mat<inp_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	activate_func<inp_type> act_fun;

	virtual ret_type forward(const inp_type& inp)
	{
		return act_fun.forward(inp);
	}

	virtual inp_type backward(const ret_type& delta)
	{
		return act_fun.backward() * delta;
	}

	virtual void update()
	{
	}
};

#endif

下面看一看用这个计算链实现RNN:

#include <conio.h>
#include "cal_chain.hpp"

template<int inp_num, int out_num>
struct rnn_node
{
	using inp_type = mat<inp_num, 1, double>;
	using ret_type = mat<out_num, 1, double>;
	inp_type St;
	inp_type dSt;
	cal_chain_node_mult<HeGaussian, nadam, inp_num, 1, inp_num, double> W,U;
	cal_chain_node_bias<nadam, inp_num, 1, double> b;
	cal_chain_node_act<sigmoid, inp_num, 1, double> f;
	
	cal_chain_node_mult<HeGaussian, nadam, inp_num, 1, out_num, double> V;
	cal_chain_node_act<softmax, out_num, 1, double> g;

	rnn_node()
	{
	}

	ret_type forward(const inp_type& X)
	{
		St.assign<0,0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
		return g.forward(V.forward(St));
	}

	inp_type backward(const ret_type& delta)
	{
		auto delta_before_b = f.backward(dSt) + f.backward(V.backward(g.backward(delta)));
		auto WU = b.backward(delta_before_b);
		dSt.assign<0,0>(W.backward(WU));
		return U.backward(WU);
	}

	void update() 
	{
		W.update();
		U.update();
		b.update();
		f.update();
		V.update();
		g.update();
	}
};

int main(int argc, char** argv) 
{
	mat<3, 1, double> mm1{.1,.2,.3};
	rnn_node<3, 2> r;
	for (int i = 0; ; ++i)
	{
		auto ret = r.forward(mm1);
		if (i % 10000 == 0)
		{
			ret.print();
			_getch();
		}

		mat<2, 1, double> mm2{.4,.6};
		r.backward(ret - mm2);
		r.update();
	}
	return 0;
}

这个用法当然不是RNN的常规用法。真正的RNN训练应该用一段固定长度的序列进行计算。最好有层内正反向两个传播节点。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

腾昵猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值