C++元编程——双向RNN

搞了一个双向的RNN,按照网上介绍的双向RNN实现出来的,但是试验结果确非常奇葩。希望有识之士能够加以解答。先上计算图:

一层包含两个方向的横向节点,最终输出是根据两个节点的输出加权得到。两个节点对于训练数据的输入顺序要求不一样,正向节点要求数据从开始到结束;反向节点是从结束到开始。 

接下来展示试验代码:

#include <conio.h>
#include "cal_chain.hpp"
#include "rnn.hpp"


int main(int argc, char** argv) 
{
	using dup_rnn_type = dup_rnn_node<3, 8, 2>;
	using dup_rnn_type1 = dup_rnn_node<8, 16, 2>;
	using dup_rnn_type2 = dup_rnn_node<16, 2, 2>;
	dup_rnn_type dr;
	dup_rnn_type1 dr1;
	dup_rnn_type2 dr2;
	dup_rnn_type::inp_type vec_inp{
		{.3, .7, .0}
		,{.5, .5, .0}
		, {.7, .3, .0}
	};
	dup_rnn_type2::ret_type vec_ret{
		{.4, .6, .4}
		,{.8, .2, .1}
		, {.1, .9, .8}
	};
	for (int i = 0; ; ++i)
	{
		auto vec_out = dr2.forward(dr1.forward(dr.forward(vec_inp)));
		dup_rnn_type2::ret_type delta = vec_out - vec_ret;
		if (i % 6000 == 0)
		{
			vec_ret.print();
			vec_out.print();
			delta.print();
			_getch();
		}
		dr.backward(dr1.backward(dr2.backward(delta)));
		dr2.update();
		dr1.update();
		dr.update();
	}
	return 0;
}

这个试验代码定义了一个3层的DRNN,输入输出分别是3->8->16->2。最初输入是3*1矩阵,一共有3个表示一个组;第二层输入是8*1矩阵;第三层输入是16*1矩阵;输出是2*1矩阵。

最终结果如下:

 可以看到,训练只要训练6000次结果就稳定了,但是问题在于出来的结果和期望结果有一定的差距。可以看出输出结果是各个输入数据的均值,这个就非常尴尬了。我觉得这和我的训练方式不当可能有关系,下次可以试验每个输入都进行多次训练,直到结果稳定再进行下一个输入的训练。

下面展示具体实现代码,首先是对计算链进行了更新,以保证加权运算和偏移运算可以正确执行,代码如下:

#ifndef _CAL_CHAIN_HPP_
#define _CAL_CHAIN_HPP_
#include "mat.hpp"
#include "base_function.hpp"

template<int inp_row, int inp_col, int ret_row, int ret_col, typename val_t>
struct cal_chain_node
{
	using base_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	using ret_type = mat<ret_row, ret_col, val_t>;
	using inp_type = mat<inp_row, inp_col, val_t>;

	virtual ret_type forward(const inp_type& inp) = 0;
	virtual inp_type backward(const ret_type& delta) = 0;
	virtual void update() = 0;
};

template<typename val_t, int inp_row, int inp_col, int ret_row, int ret_col, int...rest_row_col>
struct cal_chain 
{
	using type = val_t;
	using cur_type = cal_chain_node<inp_row, inp_col, ret_row, ret_col, val_t>;
	cur_type* sp_cur_node;
	using inp_type = typename cur_type::inp_type;
	using nxt_type = cal_chain&l
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

腾昵猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值