RNN和LSTM详解

1. Recurrent Neural Networks(RNN)

1.1 模型

在这里插入图片描述
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h] ht=tanh[WhxXt+Whhht1+bh]
z t = f ( W h y h t + b z ) z_t=f(W_{hy}h_t+b_z) zt=f(Whyht+bz)

  • t a n h ( v ) = e x p ( 2 v ) − 1 e x p ( 2 v ) + 1 tanh(v) = \frac{exp(2v)-1}{exp(2v)+1} tanh(v)=exp(2v)+1exp(2v)1
  • W h h , W x h , W h y W_{hh},W_{xh},W_{hy} Whh,Wxh,Why都是可训练的权重矩阵。
  • b h , b z b_h,b_z bh,bz都是可训练的偏差向量。
  • X t X_t Xt z t z_t zt分别是时间 t t t的输入和输出。

1.2 损失函数

L τ ( θ ) = ∑ t ∈ τ L ( y t , z t ) L_\tau(\theta) = \sum_{t\in\tau}L(y_t,z_t) Lτ(θ)=tτL(yt,zt)
这里的 τ \tau τ是输出序列。

1.3 不同形态的RNN

在这里插入图片描述
应用场景:

  • One-to-many: image captioning;
  • Many-to-one: text sentiment classification;
  • Many-to-many: machine translation.

1.4 多层RNN

回想一下单层RNN:
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] = t a n h [ W ( X t h t − 1 1 ) ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h]=tanh\begin{bmatrix}W\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} ht=tanh[WhxXt+Whhht1+bh]=tanhWXtht11

多层RNN是单层RNN堆叠而来的:
在这里插入图片描述

h t l = t a n h [ W ( h t l − 1 h t − 1 1 ) ] h_t^l =tanh\begin{bmatrix}W\begin{pmatrix}h_t^{l-1}\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} htl=tanhWhtl1ht11

高层的隐含状态 h t l h_t^l htl由老的状态 h t − 1 l h_{t-1}^l ht1l和低层的隐含状态 h t ( l − 1 ) h_t^(l-1) ht(l1)决定。

1.5 RNN存在的问题

普通RNN的一个显著缺点是,当序列长度很大时,RNN难以捕获序列数据中的长依赖项。这有时是梯度消失/爆炸造成的。
在下面的例子中,计算 ∂ L τ ∂ h 1 \frac{\partial L_\tau}{\partial h_1} h1Lτ时,根据链式求导法则,我们需要计算 ∏ t = 1 3 ( ∂ h t + 1 ∂ h t ) \prod_{t=1}^3(\frac{\partial h_{t+1}}{\partial h_t}) t=13(htht+1)
在这里插入图片描述
如果序列很长,这个乘积将是许多雅可比矩阵的乘积,这通常会得到指数大或指数小的奇异值。

2. LSTM/GRU

2.1 概述

先回顾一下单层RNN:
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] = t a n h [ W ( X t h t − 1 1 ) ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h]=tanh\begin{bmatrix}W\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} ht=tanh[WhxXt+Whhht1+bh]=tanhWXtht11

对比LSTM:
( i t f t o t c t ) = ( σ σ σ t a n h ) W ( h t − 1 x t 1 ) \begin{pmatrix}i_t\\f_t\\o_t\\c_t\end{pmatrix}=\begin{pmatrix}\sigma\\\sigma\\\sigma\\tanh\end{pmatrix}W\begin{pmatrix}h_{t-1}\\x_t\\1\end{pmatrix} itftotct=σσσtanhWht1xt1

其中, σ \sigma σ是sigmoid函数。

LSTM可以删除或者添加信息到状态,并被叫“门”的结构(包括遗忘门、输入门、输出门)所限制。
在这里插入图片描述

2.2 遗忘门(Forget gate)

在这里插入图片描述

功能:保存旧的信息
f t = σ [ W f ( X t h t − 1 1 ) ] f_t =\sigma\begin{bmatrix}W_f\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} ft=σWfXtht11

理想情况下,遗忘门的输出具有接近二进制的值,例如,当 f t f_t ft的输出接近1时可能表明输入序列中存在某个特征。

2.3 输入门(Input gate)

在这里插入图片描述
功能:更新记忆

i t = σ [ W i ( X t h t − 1 1 ) ] i_t =\sigma\begin{bmatrix}W_i\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} it=σWiXtht11
c ˉ t = t a n h [ W c ( X t h t − 1 1 ) ] \bar c_t=tanh\begin{bmatrix}W_c\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} cˉt=tanhWcXtht11

2.4 输入门和遗忘门的合并

在这里插入图片描述
c t = f t ⊙ c t − 1 + i t ⊙ c ˉ t c_t=f_t\odot c_{t-1}+i_t \odot \bar c_t ct=ftct1+itcˉt

⊙ \odot 表示两个矩阵对应位置元素进行乘积

2.4 输出门(Output gate)

在这里插入图片描述
功能:决定有多少记忆 c t c_t ct影响输出 h t h_t ht

o t = σ [ W o ( X t h t − 1 1 ) ] o_t =\sigma\begin{bmatrix}W_o\begin{pmatrix}X_t\\h_{t-1}\\1\end{pmatrix}\end{bmatrix} ot=σWoXtht11

h t = o t ⊙ t a n h ( c t ) h_t=o_t \odot tanh(c_t) ht=ottanh(ct)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值