书接上文,话分两头。通过节点建立RNN有点麻烦,现在又搞了一个深度RNN,就是把单个RNN堆叠起来。废话不多说,直接上代码:
// One layer of a stacked (deep) RNN: St = f(U*X + W*St_prev + b), Y = g(V*St).
// inp_num_ / out_num_ are the input/state and output vector heights.
// NOTE: the template parameters are renamed so that the public constant
// rnn_node::inp_num can be exposed — a member declaration that reuses a
// template parameter's name (the original `static constexpr int inp_num = inp_num;`)
// is ill-formed.
template<int inp_num_, int out_num_>
struct rnn_node :public cal_chain_node<inp_num_, 1, out_num_, 1, double>
{
	using inp_type = mat<inp_num_, 1, double>;
	using ret_type = mat<out_num_, 1, double>;
	inp_type St;   // hidden state carried across time steps
	inp_type dSt;  // state gradient from the next step (default-constructed before the first backward — presumably zero; TODO confirm mat zero-initializes)
	cal_chain_node_mult<HeGaussian, nadam, inp_num_, 1, inp_num_, double> W, U;  // W: state->state, U: input->state
	cal_chain_node_bias<nadam, inp_num_, 1, double> b;
	cal_chain_node_act<sigmoid, inp_num_, 1, double> f;   // state activation
	cal_chain_node_mult<HeGaussian, nadam, inp_num_, 1, out_num_, double> V;     // state->output
	cal_chain_node_act<softmax, out_num_, 1, double> g;   // output activation
	static constexpr int inp_num = inp_num_;  // public constant preserved for callers
	rnn_node()
	{
	}
	// Forward one time step: updates the stored state St, returns g(V*St).
	ret_type forward(const inp_type& X) override
	{
		// `.template` is required: St has a dependent type, so assign<0,0>
		// must be disambiguated as a member template.
		St.template assign<0, 0>(f.forward(b.forward(U.forward(X) + W.forward(St))));
		return g.forward(V.forward(St));
	}
	// Backward one time step: averages the state-gradient path (dSt) with the
	// current output-gradient path, then pushes the result through b/W/U.
	inp_type backward(const ret_type& delta) override
	{
		auto delta_before_b = (f.backward(dSt) + f.backward(V.backward(g.backward(delta)))) / 2.;
		auto WU = b.backward(delta_before_b);
		dSt.template assign<0, 0>(W.backward(WU));  // stash state gradient for the previous step
		return U.backward(WU);
	}
	// Apply accumulated gradients on every trainable sub-node.
	void update() override
	{
		W.update();
		U.update();
		b.update();
		f.update();
		V.update();
		g.update();
	}
};
template<typename cur_t, typename val_t, int cur_node_num, int...nodes_num>
struct rnn_type_cal
{
using added_t = typename cur_t::template add_tool<cur_node_num, 1>::type;
using chain_type = typename rnn_type_cal<added_t, val_t, nodes_num...>::chain_type;
};
// Recursion terminator for a single remaining layer size: append it to the
// chain type and stop.
template<typename cur_t, typename val_t, int cur_node_num>
struct rnn_type_cal<cur_t, val_t, cur_node_num>
{
using added_t = typename cur_t::template add_tool<cur_node_num, 1>::type;
using chain_type = added_t;
};
// Computes the cal_chain type for an RNN whose layer sizes are
// first_node, second_node, nodes_num... (each layer is an <n, 1> column vector).
// NOTE(review): with exactly two sizes, rnn_type_cal is instantiated with an
// empty pack — verify rnn_type_cal provides an empty-pack case so a two-layer
// make_rnn compiles.
template<typename val_t, int first_node, int second_node, int...nodes_num>
struct rnn_type_def
{
using first_chain = cal_chain<val_t, first_node, 1, second_node, 1>;
using chain_type = typename rnn_type_cal< first_chain, val_t, nodes_num...>::chain_type;
};
// Recursively fills chain slot N (and onwards) with heap-allocated rnn_node
// instances, one per consecutive pair of layer sizes.
template<int first_node, int second_node, int...nodes_num>
struct gen_rnn
{
	template<int N, typename chain_type>
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node<first_node, second_node>;
		// NOTE(review): the node is never deleted — the chain keeps only a raw
		// pointer, so this allocation leaks; consider owning storage. TODO confirm.
		cur_node_t* cur_node = new cur_node_t;
		// `template` is required: chn has a dependent type, so set<N> must be
		// disambiguated as a member template.
		chn.template set<N>(cur_node);
		gen_rnn<second_node, nodes_num...>::template gen_rnn_chain<N + 1>(chn);
	}
};
// Recursion terminator: installs the last rnn_node of the network.
template<int first_node, int second_node>
struct gen_rnn<first_node, second_node>
{
	template<int N, typename chain_type>
	static void gen_rnn_chain(chain_type& chn)
	{
		using cur_node_t = rnn_node<first_node, second_node>;
		// NOTE(review): leaked like the nodes above — the chain does not own it.
		cur_node_t* cur_node = new cur_node_t;
		// `template` disambiguator required on the dependent member template.
		chn.template set<N>(cur_node);
	}
};
// Builds a fully wired deep-RNN chain for the given layer sizes, e.g.
// make_rnn<double, 4, 8, 3>() produces rnn_node<4,8> followed by rnn_node<8,3>.
template<typename val_t, int...node_nums>
typename rnn_type_def<val_t, node_nums...>::chain_type make_rnn()
{
	using chain_type = typename rnn_type_def<val_t, node_nums...>::chain_type;
	chain_type ret;
	// `::template` is required: gen_rnn<node_nums...> depends on this
	// function's template parameters, so gen_rnn_chain is a dependent
	// member template.
	gen_rnn<node_nums...>::template gen_rnn_chain<0>(ret);
	return ret;
}
对于上一篇文章中的计算链cal_chain增加了一个算法,用于推导增加元素后的类型(add_tool)。计算链全部代码如下:
#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include <memory>
#include "mat.hpp"
#include "base_function.hpp"
// Abstract interface for one node of a computation chain: maps an
// <inp_row x inp_col> matrix to a <ret_row x ret_col> matrix, propagates the
// gradient back through itself, and applies accumulated parameter updates.
template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>
struct cal_chain_node
{
	using base_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	using ret_type = mat<ret_row, ret_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	// Virtual destructor: instances are handled through base pointers
	// (cal_chain stores cal_chain_node*), so a polymorphic base without one
	// would make delete-through-base undefined behavior.
	virtual ~cal_chain_node() = default;
	virtual ret_type forward(const inp_type& inp) = 0;
	virtual inp_type backward(const ret_type& delta) = 0;
	virtual void update() = 0;
};
// A statically-typed chain of computation nodes. rest_row_col continues the
// (row, col) size list; each consecutive size pair is one node. The chain owns
// its tail (sp_next) but NOT the nodes themselves — set<>() stores raw
// observer pointers.
template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
struct cal_chain
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node = nullptr;  // non-owning; must be set via set<>() before forward/backward/update
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain<val_t, ret_row, ret_col, rest_row_col...>;
	std::shared_ptr<nxt_type> sp_next;
	using ret_type = typename nxt_type::ret_type;
	cal_chain()
	{
		sp_next = std::make_shared<nxt_type>();
	}
	// Feed the input through this node, then through the rest of the chain.
	auto forward(const inp_type& inp)
	{
		return sp_next->forward(sp_cur_node->forward(inp));
	}
	// Propagate the gradient through the tail first, then through this node.
	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(sp_next->backward(ret));
	}
	void update()
	{
		sp_cur_node->update();
		if (sp_next)
			sp_next->update();
	}
	// Installs node pointer sp at position `remain` down the chain.
	// NOTE(review): dynamic_cast yields nullptr when sp's dynamic type does not
	// match this slot's node type — consider asserting. TODO confirm intended.
	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>& set(set_type* sp)
	{
		if constexpr (remain != 0)
		{
			// `template` is required: sp_next has a dependent type, so set<...>
			// must be disambiguated as a member template.
			sp_next->template set<remain - 1, set_type>(sp);
		}
		if constexpr (remain == 0)
		{
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		}
		return *this;
	}
	// Type of the chain obtained by appending one more <insert_row x insert_col> stage.
	template<int insert_row, int insert_col>
	struct add_tool
	{
		using type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col..., insert_row, insert_col>;
	};
};
// Chain terminator: a single-node chain with no tail.
template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col>
struct cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node = nullptr;  // non-owning; must be set via set<>() before use
	using inp_type = typename cur_type::inp_type;
	using ret_type = typename cur_type::ret_type;
	cal_chain()
	{
	}
	auto forward(const inp_type& inp)
	{
		return sp_cur_node->forward(inp);
	}
	auto backward(const ret_type& ret)
	{
		return sp_cur_node->backward(ret);
	}
	void update()
	{
		sp_cur_node->update();
	}
	// Installs the final node; indexing past the end is a compile-time error.
	template<int remain, typename set_type>
	cal_chain<val_t, inp_row, inp_col, ret_row, ret_col>& set(set_type* sp)
	{
		static_assert(remain == 0, "over over");
		if constexpr (remain == 0)
			sp_cur_node = dynamic_cast<cur_type*>(sp);
		return *this;
	}
	// Type of the chain obtained by appending one more stage.
	template<int insert_row, int insert_col>
	struct add_tool
	{
		using type = cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, insert_row, insert_col>;
	};
};
// Adapts a whole cal_chain so it can itself be used as a single
// cal_chain_node, making chains composable.
// NOTE(review): assumes chain::inp_type / ret_type expose static extents
// `r` and `c` — confirm against mat.hpp.
template<typename chain>
struct cal_chain_container :public cal_chain_node<chain::inp_type::r, chain::inp_type::c, chain::ret_type::r, chain::ret_type::c, typename chain::type>
{
	using inp_type = typename chain::inp_type;
	using ret_type = typename chain::ret_type;
	chain* p;  // non-owning pointer to the wrapped chain
	// explicit: a bare chain pointer should not silently convert to a container.
	explicit cal_chain_container(chain* pv) :p(pv)
	{}
	ret_type forward(const inp_type& inp) override
	{
		return p->forward(inp);
	}
	inp_type backward(const ret_type& ret) override
	{
		return p->backward(ret);
	}
	void update() override
	{
		p->update();
	}
};
// Wraps an existing chain in a cal_chain_container so the whole chain can be
// plugged into another chain as a single node. The container observes `p`;
// it does not take ownership.
template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
auto make_chain_node(cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>* p)
{
	return cal_chain_container<cal_chain<val_t, inp_row, inp_col, ret_row, ret_col, rest_row_col...>>(p);
}
#include "weight_initilizer.hpp"
#include "update_methods.hpp"
// Trainable matrix-multiply node: forward computes W * inp; backward keeps a
// running average of the weight gradient and returns W^T * delta.
// ini_t selects the weight initializer, update_method_templ the optimizer.
template<typename ini_t, template<typename> class update_method_templ, int inp_row, int inp_col, int ret_row, typename val_t>
struct cal_chain_node_mult :public cal_chain_node<inp_row, inp_col, ret_row, inp_col, val_t>
{
	using ret_type = mat<ret_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	using weight_type = mat<ret_row, inp_row, val_t>;
	weight_type W;
	weight_type deltaW;  // running-average gradient accumulated since the last update()
	inp_type pre_inp;    // input cached by forward() for gradient computation
	update_method_templ<weight_type> um;
	double d_num;        // number of samples folded into deltaW
	cal_chain_node_mult() :d_num(0.)
	{
		weight_initilizer<ini_t>::cal(W);
	}
	virtual ret_type forward(const inp_type& inp) override
	{
		pre_inp = inp;
		return W.dot(inp);
	}
	virtual inp_type backward(const ret_type& delta) override
	{
		auto ret = W.t().dot(delta);
		// Incremental mean: avg_n = (avg_{n-1} * (n-1) + grad_n) / n.
		deltaW = deltaW * d_num + delta.dot(pre_inp.t());
		d_num = d_num + 1.;
		if (d_num > 1e-7)
		{
			deltaW = deltaW / d_num;
		}
		return ret;
	}
	virtual void update() override
	{
		// `.template` is required: W has a dependent type, so assign<0,0>
		// must be disambiguated as a member template.
		W.template assign<0, 0>(um.update(W, deltaW));
		deltaW = 0.;
		d_num = 0.;
	}
};
// Trainable bias node: forward adds b to the input; backward keeps a running
// average of the bias gradient and passes delta through unchanged.
// NOTE(review): b relies on mat's default construction for its initial value —
// presumably zero; TODO confirm against mat.hpp.
template<template<typename> class update_method_templ, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_bias :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
	using ret_type = mat<inp_row, inp_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;
	update_method_templ<inp_type> um;
	inp_type b;
	inp_type deltab;  // running-average gradient since the last update()
	double d_num;     // number of samples folded into deltab
	cal_chain_node_bias() :d_num(0.)
	{}
	virtual ret_type forward(const inp_type& inp) override
	{
		return b + (inp);
	}
	virtual inp_type backward(const ret_type& delta) override
	{
		// Incremental mean: avg_n = (avg_{n-1} * (n-1) + grad_n) / n.
		deltab = deltab * d_num + delta;
		d_num = d_num + 1.;
		if (d_num > 1e-7)
		{
			deltab = deltab / d_num;
		}
		return delta;
	}
	virtual void update() override
	{
		// `.template` is required: b has a dependent type, so assign<0,0>
		// must be disambiguated as a member template.
		b.template assign<0, 0>(um.update(b, deltab));
		deltab = 0.;
		d_num = 0.;
	}
};
#include "activate_function.hpp"
template<template<typename> class activate_func, int inp_row, int inp_col, typename val_t>
struct cal_chain_node_act :public cal_chain_node<inp_row, inp_col, inp_row, inp_col, val_t>
{
using ret_type = mat<inp_row, inp_col, val_t>;
using inp_type = mat<inp_row, inp_col, val_t>;
activate_func<inp_type> act_fun;
virtual ret_type forward(const inp_type& inp)
{
return act_fun.forward(inp);
}
virtual inp_type backward(const ret_type& delta)
{
return act_fun.backward() * delta;
}
virtual void update()
{
}
};
#endif
最后运算结果显示,这个多层的RNN也是有学习能力的。由于没有足够多的有关联数据(也比较懒,不想搞),所以就没正经试验过。有兴趣的小朋友们可以用自己的数据试一试。