自然语言处理系列(一)——RNN基础

注: 本文是总结性文章,叙述较为简洁,不适合初学者

一、为什么要有RNN?

普通的MLP无法处理序列信息(如文本、语音等),这是因为序列是不定长的,而MLP的输入层神经元个数是固定的。

二、RNN的结构

普通MLP的结构(以单隐层为例):

在这里插入图片描述

普通RNN(又称Vanilla RNN,接下来都将使用这一说法)的结构(在单隐层MLP的基础上进行改造):

在这里插入图片描述

t t t 时刻隐藏层接收的输入来自于 t − 1 t-1 t1 时刻隐藏层的输出和 t t t 时刻的样例输入。用数学公式表示,就是

h ( t ) = tanh ⁡ ( W h ( t − 1 ) + U x ( t ) + b ) , o ( t ) = V h ( t ) + c , y ^ ( t ) = softmax ( o ( t ) ) h^{(t)}=\tanh(Wh^{(t-1)}+Ux^{(t)}+b),\quad o^{(t)}=Vh^{(t)}+c,\quad \hat{y}^{(t)}=\text{softmax}(o^{(t)}) h(t)=tanh(Wh(t1)+Ux(t)+b),o(t)=Vh(t)+c,y^(t)=softmax(o(t))

训练RNN的过程中,实际上就是在学习 U , V , W , b , c U,V,W,b,c U,V,W,b,c 这些参数。

正向传播后,我们需要计算损失,设时间步 t t t 处求得的损失为 L ( t ) = L ( t ) ( y ^ ( t ) , y ( t ) ) L^{(t)}=L^{(t)}(\hat{y}^{(t)},y^{(t)}) L(t)=L(t)(y^(t),y(t)),则总的损失为 L = ∑ t = 1 T L ( t ) L=\sum_{t=1}^T L^{(t)} L=t=1TL(t)

2.1 BPTT

BPTT(BackPropagation Through Time),通过时间反向传播是RNN训练过程中的一个术语。因为正向传播时是沿着时间流逝的方向进行的,而反向传播则是逆着时间进行的。

为方便后续推导,我们先改进一下符号表述:

h ( t ) = tanh ⁡ ( W h h h ( t − 1 ) + W x h x ( t ) + b ) , o ( t ) = W h o h ( t ) + c , y ^ ( t ) = softmax ( o ( t ) ) h^{(t)}=\tanh(W_{hh}h^{(t-1)}+W_{xh}x^{(t)}+b),\quad o^{(t)}=W_{ho}h^{(t)}+c,\quad \hat{y}^{(t)}=\text{softmax}(o^{(t)}) h(t)=tanh(Whhh(t1)+Wxhx(t)+b),o(t)=Whoh(t)+c,y^(t)=softmax(o(t))

做一个水平方向的 concatenation: W = ( W h h , W x h ) W=(W_{hh},W_{xh}) W=(Whh,Wxh),为简便起见,省略偏置 b b b,则有

h ( t ) = tanh ⁡ ( W ( h ( t − 1 ) x ( t ) ) ) h^{(t)}=\tanh\left(W \begin{pmatrix} h^{(t-1)} \\ x^{(t)} \end{pmatrix} \right) h(t)=tanh(W(h(t1)x(t)))

,接下来我们将关注参数 W W W 的学习。

注意到

∂ h ( t ) ∂ h ( t − 1 ) = tanh ⁡ ′ ( W h h h ( t − 1 ) + W x h x ( t ) ) W h h , ∂ L ∂ W = ∑ t = 1 T ∂ L ( t ) ∂ W \frac{\partial h^{(t)}}{\partial h^{(t-1)}}=\tanh'(W_{hh}h^{(t-1)}+W_{xh}x^{(t)})W_{hh},\quad \frac{\partial L}{\partial W}=\sum_{t=1}^T\frac{\partial L^{(t)}}{\partial W} h(t1)h(t)=tanh(Whhh(t1)+Wxhx(t))Whh,WL=t=1TWL(t)

从而

∂ L ( T ) ∂ W = ∂ L ( T ) ∂ h ( T ) ⋅ ∂ h ( T ) ∂ h ( T − 1 ) ⋯ ∂ h ( 2 ) ∂ h ( 1 ) ⋅ ∂ h ( 1 ) ∂ W = ∂ L ( T ) ∂ h ( T ) ⋅ ∏ t = 2 T ∂ h ( t ) ∂ h ( t − 1 ) ⋅ ∂ h ( 1 ) ∂ W = ∂ L ( T ) ∂ h ( T ) ⋅ ( ∏ t = 2 T tanh ⁡ ′ ( W h h h ( t − 1 ) + W x h x ( t ) ) ) ⋅ W h h T − 1 ⋅ ∂ h ( 1 ) ∂ W \begin{aligned} \frac{\partial L^{(T)}}{\partial W}&=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \frac{\partial h^{(T)}}{\partial h^{(T-1)}}\cdots \frac{\partial h^{(2)}}{\partial h^{(1)}}\cdot\frac{\partial h^{(1)}}{\partial W} \\ &=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \prod_{t=2}^T\frac{\partial h^{(t)}}{\partial h^{(t-1)}}\cdot\frac{\partial h^{(1)}}{\partial W}\\ &=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \left(\prod_{t=2}^T\tanh'(W_{hh}h^{(t-1)}+W_{xh}x^{(t)})\right)\cdot W_{hh}^{T-1} \cdot\frac{\partial h^{(1)}}{\partial W}\\ \end{aligned} WL(T)=h(T)L(T)h(T1)h(T)h(1)h(2)Wh(1)=h(T)L(T)t=2Th(t1)h(t)Wh(1)=h(T)L(T)(t=2Ttanh(Whhh(t1)+Wxhx(t)))WhhT1Wh(1)

因为 tanh ⁡ ′ ( ⋅ ) \tanh'(\cdot) tanh() 几乎总是小于 1 1 1 的,当 T T T 足够大时将会出现梯度消失现象。


假如不采用非线性的激活函数,为简便起见,不妨设激活函数为恒等映射 f ( x ) = x f(x)=x f(x)=x,于是有

∂ L ( T ) ∂ W = ∂ L ( T ) ∂ h ( T ) ⋅ W h h T − 1 ⋅ ∂ h ( 1 ) ∂ W \frac{\partial L^{(T)}}{\partial W}=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot W_{hh}^{T-1} \cdot\frac{\partial h^{(1)}}{\partial W} WL(T)=h(T)L(T)WhhT1Wh(1)

  • W h h W_{hh} Whh 的最大奇异值大于 1 1 1 时,会出现梯度爆炸。
  • W h h W_{hh} Whh 的最大奇异值小于 1 1 1 时,会出现梯度消失。

三、RNN的分类

按照输入和输出的结构可以对RNN进行如下分类:

  • 1 vs N(vec2seq):Image Captioning;
  • N vs 1(seq2vec):Sentiment Analysis;
  • N vs M(seq2seq):Machine Translation;
  • N vs N(seq2seq):Sequence Labeling(POS Tagging)

在这里插入图片描述

注意 1 vs 1 是传统的MLP。

若按照内部构造进行分类则会得到:

  • RNN、Bi-RNN、…
  • LSTM、Bi-LSTM、…
  • GRU、Bi-GRU、…

四、Vanilla RNN的优缺点

优点:

  • 可以处理不定长的序列;
  • 计算时会考虑历史信息;
  • 权重沿时间方向上是共享的;
  • 模型大小不会随着输入大小增加而改变。

缺点:

  • 计算效率低;
  • 梯度会消失/爆炸(后续将知道,避免梯度爆炸可采用梯度裁剪,避免梯度消失可换用其他的RNN结构,如LSTM);
  • 无法处理长序列(即不具备长记忆性);
  • 无法利用未来的输入(Bi-RNN可解决)。

五、Bidirectional RNN

许多时候,我们要输出的 y ( t ) y^{(t)} y(t) 可能依赖于整个序列,因此需要使用双向RNN(BRNN)。BRNN结合了时间上从序列起点开始移动的RNN和从序列末尾开始移动的RNN。两个RNN互相独立不共享权重:

在这里插入图片描述
相应的计算方式变为:

h ( t ) = tanh ⁡ ( W 1 h ( t − 1 ) + U 1 x ( t ) + b 1 ) g ( t ) = tanh ⁡ ( W 2 h ( t − 1 ) + U 2 x ( t ) + b 2 ) o ( t ) = V ( h ( t ) ; g ( t ) ) + c y ^ ( t ) = softmax ( o ( t ) ) \begin{aligned} &h^{(t)}=\tanh(W_1h^{(t-1)}+U_1x^{(t)}+b_1) \\ &g^{(t)}=\tanh(W_2h^{(t-1)}+U_2x^{(t)}+b_2) \\ &o^{(t)}=V(h^{(t)};g^{(t)})+c \\ &\hat{y}^{(t)}=\text{softmax}(o^{(t)}) \\ \end{aligned} h(t)=tanh(W1h(t1)+U1x(t)+b1)g(t)=tanh(W2h(t1)+U2x(t)+b2)o(t)=V(h(t);g(t))+cy^(t)=softmax(o(t))

其中 ( h ( t ) ; g ( t ) ) (h^{(t)};g^{(t)}) (h(t);g(t)) 代表将两个列向量 h ( t ) h^{(t)} h(t) g ( t ) g^{(t)} g(t) 进行纵向连接。

事实上,若将 V V V 按列分块,则上述的第三个等式还可写成:

o ( t ) = V ( h ( t ) ; g ( t ) ) + c = ( V 1 , V 2 ) ( h ( t ) g ( t ) ) + c = V 1 h ( t ) + V 2 g ( t ) + c o^{(t)}=V(h^{(t)};g^{(t)})+c= (V_1,V_2) \begin{pmatrix} h^{(t)} \\ g^{(t)} \end{pmatrix}+c=V_1h^{(t)}+V_2g^{(t)}+c o(t)=V(h(t);g(t))+c=(V1,V2)(h(t)g(t))+c=V1h(t)+V2g(t)+c

训练 BRNN 的过程实际就是在学习 U 1 , U 2 , V , W 1 , W 2 , b 1 , b 2 , c U_1,U_2,V,W_1,W_2,b_1,b_2,c U1,U2,V,W1,W2,b1,b2,c 这些参数。

六、Stacked RNN

堆叠RNN又称多层RNN或深度RNN,即由多个隐藏层组成。以双隐层单向RNN为例,其结构如下:

在这里插入图片描述

相应的计算过程如下:

h ( t ) = tanh ⁡ ( W h h h ( t − 1 ) + W x h x ( t ) + b h ) z ( t ) = tanh ⁡ ( W z z z ( t − 1 ) + W h z h ( t ) + b z ) o ( t ) = W z o z ( t ) + b o y ^ ( t ) = softmax ( o ( t ) ) \begin{aligned} &h^{(t)}=\tanh(W_{hh}h^{(t-1)}+W_{xh}x^{(t)}+b_h) \\ &z^{(t)}=\tanh(W_{zz}z^{(t-1)}+W_{hz}h^{(t)}+b_z) \\ &o^{(t)}=W_{zo}z^{(t)}+b_o \\ &\hat{y}^{(t)}=\text{softmax}(o^{(t)}) \\ \end{aligned} h(t)=tanh(Whhh(t1)+Wxhx(t)+bh)z(t)=tanh(Wzzz(t1)+Whzh(t)+bz)o(t)=Wzoz(t)+boy^(t)=softmax(o(t))

  • 12
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Iareges

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值