循环神经网络RNN模型

一、基础概念

1.1 结构

![[Pasted image 20230605110540.png]]

1.2 RNN与传统NN对比

![[Pasted image 20230605110614.png]]

  • 区别
    • 传统NN
      • 总是在同一时刻接受固定量的数据
      • 输出的数量是固定的
      • 没有时效性的概念
    • RNN
      • 不会在同一时刻输入全部数据
      • 数据按时间先后顺序输入模型进行计算,使数据保留了时效性
      • 每一时刻,RNN都会进行一系列计算后再输出结果(隐藏状态)
      • 隐藏状态会和下一时刻的输入进行拼接,用于生成下一时刻的输出
      • 在输入序列结束时,才会输出结果
  • 简单的说,RNN在每一个时间步上可以用上一个时间步的隐藏信息作为输入信息,确保过去的信息可以传递到未来。
  • (注意:RNN的神经单元是可以共享参数)

二、输出模式

输出模式分为两种,一种是只输出最后一个时刻的结果,另一种是每一个时刻的结果都输出。

2.1 只输出最后一个时刻的结果

![[Pasted image 20230605111628.png]]

  • 采用RNN训练序列最后一个神经元作为输出结果
  • 由于此最终输出已经通过所有先前单元格进行了计算,因此已经捕获所有先前输入的上下文(所有输入序列均完成计算)
  • 最终结果取决于之前所有计算和输入
  • 适用于文本分类/判别场景

2.2 每一时刻的结果均输出

![[Pasted image 20230605112713.png]]

  • 采用RNN训练序列每一个神经元获取结果作为输出,(需要经过softmax/Relu作归一化处理之后输出)
  • 在有必要的情况下,也会将每一时刻的输出作为下一时刻的输入
  • 适用于逐字转换的场景

2.3 综合输出模式

综合了以上两种输出模式的综合输出模式。
![[Pasted image 20230605112730.png]]

- 采用RNN训练序列的最后一个神经元作为前半部分的输出,同时作为后半部分的输入。后半部分往往是使用新的训练序列每⼀时刻获取结果作为输出。
  这里类似于[[Transformer模型]],前半部分是encoder,后半部分是decoder。encoder的部分可以选择经过Softmax/sigmoid函数归一化后输出。我们可以用预测文本字典去限制输出为汉语。
- 输入的信息走完全部时间步后,生成的final state作为下游的输入
- 适用于seq2seq场景            [[seq2seq模型]]

三、模型架构和参数

参考资料
一文搞懂RNN(循环神经网络)基础篇 - 知乎 (zhihu.com)
【循环神经网络】5分钟搞懂RNN,3D动画深入浅出_哔哩哔哩_bilibili
【数之道 09】揭开循环神经网络RNN模型的面纱_哔哩哔哩_bilibili

3.1 模型架构

![[Pasted image 20230605144907.png]]

一个简单的RNN由输入层Layer Input,隐藏层Dense hidden,和输出层Layer Output组成。

3.2 模型参数

  1. 输入:
    • Input Tensor
    • Hidden State Tensor(init state 0)
      (可以是input输入也可以是隐藏状态作为输入)
  2. 权重矩阵
    • 输入全连接矩阵 W i W_i Wi
    • 隐藏全连接矩阵 W h W_h Wh
    • 输出全连接矩阵 W o W_o Wo
    • 以上三个矩阵在RNN中的线性相关参数矩阵,在整个RNN网络中都是共享的,之后通过bp算法进行参数的更新,所以体现了RNN循环反馈的思想。
  3. 输出
    • New Hidden State
    • Output Tensor
      ![[Pasted image 20230605145337.png]]

3.3 RN在pytorch中的调用

  1. 激活函数
    首先介绍三个激活函数
    ![[Pasted image 20230605145658.png]]

    显然,sigmoid函数把实数映射到[0,1]的值域上,而tanh函数将实数映射到[-1,1]的值域上,需要保留正负效应时使用tanh函数,而无方向,正负效应则选择sigmoid函数。而Relu函数相差比较大,在x=0处显然不连续,所以可以用于过滤信息的作用。

  2. RNN输出计算公式
    h t = t a n h ( x t W i h T + b i h + h t − 1 W h h T + b h h ) h_t=tanh(x_tW_{ih}^T+b_{ih}+h_{t-1}W_{hh}^T+b_{hh}) ht=tanh(xtWihT+bih+ht1WhhT+bhh)

    • h t h_t ht 是t时刻的隐藏状态
    • x t x_t xt 是t时刻的输入张量
    • h t − 1 h_{t-1} ht1 是t-1时刻的隐藏状态(初始化时为0)默认使用tanh函数激活。(选取非线性param时使用Relu函数)
    • b i h b_{ih} bih 和指的是input to hidden的bias
    • b h h b_{hh} bhh 指的是hidden to hidden dense的bias
    • W i h W_{ih} Wih 指的是input to hidden的权重矩阵
    • W h h W_{hh} Whh 指的是hidden to hidden的权重矩阵

    将公式拆开就可以得到输入向量经权重矩阵变换的公式:
    ![[Pasted image 20230605152239.png]]

    u t = W i ∗ x t u_t=W_i*x_t ut=Wixt
    上一时刻隐藏状态的变换公式: h t ′ = W h ∗ h t − 1 h_t'=W_h*h_{t-1} ht=Whht1
    而隐藏状态需要输出还需要经过Softmax函数进行归一化处理。对当前隐状态进行输出的公式: 输出: o t = S o f t m a x ( h t ∗ W h o T ) 输出:o_t=Softmax(h_t*W_{ho}^T) 输出:ot=Softmax(htWhoT)![[Pasted image 20230605152603.png]]

    RNN(sequence, batch_size, feature)

    • sequence:输入数据序列,按时间顺序输入模型
    • batch_size:表示某时刻下输入信息的个数
    • feature:表示每个信息的特征维度

四、BPTT

Back Propagation Through Time 时序反向传播算法
![[Pasted image 20230605153717.png]]

  • U是与 x t x_t xt 相关的权重矩阵, W i W_i Wi
  • W是RNN隐藏状态的权重矩阵, W h W_h Wh
  • V是RNN生成结果的权重矩阵(before Softmax), W o W_o Wo

从这里开始,我们考虑RNN经过t时刻网络训练后的前向传播和反向传播。
前向传播设计在输入上应用激活函数并最后返回预测结果。

  • 反向传播用于计算和训练网络权重相关的目标函数梯度,最终更新这些参数。
  • 从整体结构上看,RNN有三个可训练的权重矩阵,U( W i W_i Wi) W( W h W_h Wh) V( W o W_o Wo)
  • 所以需要计算目标函数关于三个权重的导数

4.1 反向传播

我们使用 w h w_h wh w o w_o wo 来表示隐藏层和输出层的权重
{ h t = f ( x t , h t − 1 , w i , w h ) o t = g ( h t , o t ) \left\{ \begin{aligned} h_t&=f(x_t,h_{t-1},w_i,w_h)\\ o_t&=g(h_t,o_t) \\ \end{aligned} \right. {htot=f(xt,ht1,wi,wh)=g(ht,ot)
通过一个目标函数L在所有T个时间步内评估输出 o t o_t ot 和对应的标签 y t y_t yt之间的损失, l l l 是时刻t的损失值 L ( x 1 , . . . , x T , y 1 , . . . , y T , w i , w h , w o ) = 1 T ∑ t = 1 T l ( y t , o t ) L(x_1,...,x_T,y_1,...,y_T,w_i,w_h,w_o)=\frac 1 T\sum^T_{t=1}l(y_t,o_t) L(x1,...,xT,y1,...,yT,wi,wh,wo)=T1t=1Tl(yt,ot)

  • 已知目标函数,计算目标函数与可训练矩阵 w h w_h wh 之间的梯度
  • 第一项是损失函数 l l l 对输出函数 o t o_t ot 的梯度
  • 第二项是输出函数 o t o_t ot 对隐状态函数 h t h_t ht 的梯度
  • 第三项是隐状态函数 h t h_t ht对可训练矩阵 w h w_h wh之间的梯度
    ∂ L ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ w h = 1 T ∑ t = 1 T ∂ ( y t , o t ) ∂ o t ∂ g ( h t , w o ) ∂ h t ∂ h t ∂ w h \begin{aligned} \frac{\partial L}{\partial w_h}&= \frac 1 T \sum^T_{t=1}\frac{\partial l(y_t,o_t)}{\partial w_h}\\ &=\frac 1 T\sum^T_{t=1}\frac{\partial(y_t,o_t)}{\partial o_t}\frac{\partial g(h_t,w_o)}{\partial h_t}\frac{\partial h_t}{\partial w_h} \end{aligned} whL=T1t=1Twhl(yt,ot)=T1t=1Tot(yt,ot)htg(ht,wo)whht由于上一时刻的隐状态依然跟 w h w_h wh 相关,所以 ∂ f ( x t , h t − 1 , w h ) ∂ w t \frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_t} wtf(xt,ht1,wh) 应用链式法则求导 ∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h + ∂ f ( x t , h t − 1 , w h ) ∂ w h \frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_t,h_{t-1},w_h)}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial w_h}+\frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_h} whht=ht1f(xt,ht1,wh)whht1+whf(xt,ht1,wh)
    假设存在三个序列 { a t } \{a_t\} {at} , { b t } \{b_t\} {bt} , { c t } \{c_t\} {ct}
    t = 1 , 2 , . . . t = 1,2,... t=1,2,... ,序列满足条件 a 0 = 0 a_0=0 a0=0 a t = b t + c t ∗ a t − 1 a_t=b_t+c_t*a_{t-1} at=bt+ctat1
    t ≥ 1 t \geq 1 t1 时,基于公式进行替换 a t = b t + c t ∗ a t − 1 = b t + c t ∗ ( b t − 1 + c t − 1 ∗ a t − 2 ) = b t + c t ∗ b t − 1 + c t ∗ c t − 1 ∗ a t − 2 = b t + c t ∗ b t − 1 + c t ∗ c t − 1 ∗ ( b t − 2 + c t − 2 ∗ a t − 3 ) = b t + c t ∗ b t − 1 + c t ∗ c t − 1 ∗ b t − 2 + c t ∗ c t − 1 ∗ a t − 3 = . . . = b t + ∑ i = 1 t − 1 ( ∏ j = i + 1 t c j ) b i \begin{aligned} a_t&=b_t+c_t*a_{t-1} \\ &=b_t+c_t*(b_{t-1}+c_{t-1}*a_{t-2}) \\ &=b_t+c_t*b_{t-1}+c_t*c_{t-1}*a_{t-2} \\ &=b_t+c_t*b_{t-1}+c_t*c_{t-1}*(b_{t-2}+c_{t-2}*a_{t-3}) \\ &=b_t+c_t*b_{t-1}+c_t*c_{t-1}*b_{t-2}+c_t*c_{t-1}*a_{t-3} \\ &= \quad ... \\ &=b_t+\sum^{t-1}_{i=1}\big(\prod_{j=i+1}^{t} c_j \big)b_i \end{aligned} at=bt+ctat1=bt+ct(bt1+ct1at2)=bt+ctbt1+ctct1at2=bt+ctbt1+ctct1(bt2+ct2at3)=bt+ctbt1+ctct1bt2+ctct1at3=...=bt+i=1t1(j=i+1tcj)bi
    公式推导完之后,再将这些序列代入 ∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h + ∂ f ( x t , h t − 1 , w h ) ∂ w h a t = ∂ h t ∂ w t b t = ∂ f ( x t , h t − 1 , w h ) ∂ w h c t = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∑ i = 1 t − 1 ( ∏ j = i + 1 t ∂ f ( x t , h t − 1 , w h ) ∂ h j − 1 ∂ f ( x t , h t − 1 , w h ) ∂ w h ) \begin{aligned} \frac{\partial h_t}{\partial w_h}&=\frac{\partial f(x_t,h_{t-1},w_h)}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial w_h}+\frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_h} \\ \\ a_t&=\frac{\partial h_t}{\partial w_t} \\ b_t&=\frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_h} \\ c_t&=\frac{\partial f(x_t,h_{t-1},w_h)}{\partial h_{t-1}} \\ \\ \frac{\partial h_t}{\partial w_h}&=\frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_h}+\sum^{t-1}_{i=1}\left( \prod^t_{j=i+1}\frac{\partial f(x_t,h_{t-1},w_h)}{\partial h_{j-1}}\frac{\partial f(x_t,h_{t-1},w_h)}{\partial w_h} \right) \end{aligned} whhtatbtctwhht=ht1f(xt,ht1,wh)whht1+whf(xt,ht1,wh)=wtht=whf(xt,ht1,wh)=ht1f(xt,ht1,wh)=whf(xt,ht1,wh)+i=1t1(j=i+1thj1f(xt,ht1,wh)whf(xt,ht1,wh))

五、RNN缺点

缺点很明显,对于长度为T的序列,我们在迭代中计算这T个时间步上的梯度,将会在反向传播过程中产生长度为O(T)的矩阵乘法链。当T的长度较长时,自然而然地会产生梯度爆炸或梯度消失的风险。
![[Pasted image 20230605171304.png]]

  • 改进措施
    1. 计算BPTT时截断时间步
      • 在时间步长 τ \tau τ 之后停止计算梯度和,使用对真实梯度的近似作为结果,在实践中给出不错的结果,但是会导致模型的观测长度变短,导致模型倾向于基于短距离的数据影响,造成bias
    2. Truncated BPTT
      • 使用随机变量来生成变量,控制Truncated的长度(通过步长因子进行调整),TBPTT使用固定长度进行裁剪。类比Top-K容易受数据分布的影响,从而通过动态选择数据分布得到了Top-P⼀样,既然固定长度的剪切会有偏差,那么我们实时更新剪切长度,从而出现了ARTBP。
    3. Anticipated Reweighted Truncated Backpropagation (ARTBP)
      • 兼顾TBPTT的优点,尝试减少模型的bias,但是从实验结果上看,ARTBP并不比TBPTT好多少。因为梯度回传的由于序列过长导致梯度消失和梯度爆炸的根源问题没有解决。
        ![[Pasted image 20230605172153.png]]
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

安徒生在ACL讲一千零一夜

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值