RNN学习笔记(六)-GRU,LSTM 代码实现

本文深入探讨GRU和LSTM的代码实现,以2-gram语言模型为例,介绍了网络结构,包括GRU和LSTM的输入层、隐层结构,并通过详细推导解释了梯度计算过程,特别关注了bptt的部分。
摘要由CSDN通过智能技术生成

RNN学习笔记(六)-GRU,LSTM 代码实现

在这篇文章里,我们将讨论GRU/LSTM的代码实现。在这里,我们仍然沿用RNN学习笔记(五)-RNN 代码实现里的例子,使用GRU/LSTM网络建立一个2-gram的语言模型。
项目源码:https://github.com/rtygbwwwerr/RNN
参考项目:https://github.com/dennybritz/rnn-tutorial-gru-lstm

1.网络结构

为了解决当词典中的words数量很大时,输入向量过长的问题,我们在输入层和隐层之间引入了Embedding Layer,通过该层,输入的one-hot将被转换为word的Embedding vector。

1.1 GRU网络

这里写图片描述

1.2 LSTM网络

略。

2.代码实现

这里我们重点讨论bptt部分(*:“ ”表示elemwise乘法运算)。对于GRU网络来说,有

zrhstz(o)(t)ot=σ(xtUz+st1Wz)=σ(xtUr+st1Wr)=tanh(xtUh+(st1r)Wh)=(1z)h+zst1=stV+c=softmax(z(o)(t))

softmax(x)=softmax(x)[1softmax(x)]
输出层节点的输入值 z(o)k(t) 导数如下:
δ(o)k(t)=Ltz(o)k(t)
=ok(t)1
写成向量形式为:
δ(o)(t)=o(t)1
LtV=δ(o)k(t)z(o)k(t)V=[ok(t)1]st
可以看到,这一层的导数与常规RNN是一致的。

从隐层开始,导数将有所不同,我们先来看下单个GRU网络节点结构:
这里写图片描述

这里,先对符号做一下约定:
iz(t)=xtUz+s(t1)Wz :t时刻update gate 对应的输入
ir(t)=xtUr+s(t1)Wr :t时刻rest gate 对应的输入
ih(t)=xtUh+(s(t1)r(t))Wh :t时刻隐单元对应的输入
io(t)=(1z(t))h(t)+z(t)s(t1) :t时刻output gate对应的输入
f(io(t))=io

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值