Seq2Seq的一些概念

Recurrent Neural Network

RNN 又叫做递归神经网络或者循环神经网络,它擅长对序列数据进行建模处理,如时间序列数据,是指在不同时间点上收集的数据,这类数据反映了某一事物、现象在随时间的变化状态或程度,当然这是时间,也可以是文本或图像序列,总的来说,序列数据存在着一个特点——后面的数据跟前面的数据有关系

为什么需要 RNN ?

神经网络结构只能单独的处理一个个的输入,前一个输入与后一个输入是完成没有关系的,但是某些任务需要更好的处理序列信息,即前面一个输入和后面一个输入是要有关系的,通俗点来说就是后一个输入需要记忆前面一个输入的信息

比如,当我们在理解一句话时,孤立的理解单个词是没有意义的,只有将上下词联系起来的整个序列才具有意义;当我们处理视频时,也不能单独分析每一帖,需要分析这些帧连接起来的整个序列

为了解决这一类问题,能够更好的处理序列信息,RNN 模型就应运而生,那么 RNN 又是怎样实现这样的功能呢?

RNN 的结构

RNN 主要对序列数据进行序列处理,其基本结构如下图所示:

上图是 RNN 的结构示意图,每个箭头表示着一次变换,也就是说箭头带有权值,左侧是折叠起来的样子,右侧是展开的样子,左侧 A 旁边的箭头体现着结构中的 “循环” 概念。

在右侧展开结构中我们可以看到,在 x 0 x_0 x0 作为输入时,该单元的输出分为二个方向,向上的 h o h_o ho 表示的是其作为一个输出,向右箭头表示的是其另一个输出作为下一个单元的输入,以此达到与下一个单元之间保持着某个联系,即记忆功能

为了更好的理解,我们看下图:

简单点来说就是:当在 x t x_t xt 时刻时,该单元的输入就分为二个: S t − 1 S_{t-1} St1 x t x_t xt, 输入也分为二个: S t S_{t} St O t O_t Ot

  • S t − 1 S_{t-1} St1 表示的是 x t − 1 x_{t-1} xt1 时刻的一个输出
  • x t x_t xt 表示本时刻的一个输入
  • S t S_{t} St 表示 x t x_t xt 时刻的一个输出,将作为下一时刻的一个输入
  • O t O_t Ot 表示 x t x_t xt 时刻的输出

我们可以用下面的公式来表示 RNN 的计算方式:

  • 上图同样与展现出了 RNN 的另一个特点:权值共享,其中 U 是完全相同的, W、V也是一样的

那么我们再来看看 隐藏层 S 中究竟发生了怎样的变化

我们可以看到 h t − 1 h_{t-1} ht1 x t x_t xt 之间实际上是做了一个 ocncatenate 操作,然后再经过激活函数最终形成了一个输出,值得注意的是它的一个维度变化

Bidirectional RNNs 双向循环神经网络

基本的 RNN 结构只能从之前时间步骤中学习,但是有时我们却需要从未来的时间步骤中学习表示,以便更好地理解上下文环境并消除歧义,通过接下来的列子,“He said, Teddy bears are on sale” and “He said, Teddy Roosevelt was a great President。在上面的两句话中,当我们看到“Teddy”和前两个词“He said”的时候,我们有可能无法理解这个句子是指President还是Teddy bears。因此,为了解决这种歧义性,我们需要往前查找。这就是双向RNN所能实现的。

如图所求,双向 RNN 有二种类型的连接,一种是前向的(Foward RNN),这有助于我们从之前的表示中学习, 另一种是后向的(Backward RNN),这有助于我们从之后的表示中学习

正向传播分为二个步骤:

  1. 我们先从左向右移动,从初始时间步骤开始计算,一直持续到到达最终时间步骤为止

  2. 再从右向左移动,从最后一个时间步骤开始计算,一直持续到到达最终时间步骤为止

  • 一般来说是从前往向计算,再从后往前计算,计算过程相互独立,互不干扰

计算预测输出值就变成了:
y ^ < t > = g ( W y [ a → < t > , a ← < t > ] ) \hat{y}^{<t>}= g(W_y[\overrightarrow{a}^{<t>},\overleftarrow{a}^{<t>}]) y^<t>=g(Wy[a <t>,a <t>])
a → < t > \overrightarrow{a}^{<t>} a <t>表示 Forward RNN 的激活函数, a ← < t > \overleftarrow{a}^{<t>} a <t> 表示 Backward RNN 的激活函数,箭头方向表示的传递方向

梯度消失和梯度爆炸

误差梯度在网络训练中用来得到网络参数的方向和步幅,在正确的方向下以合适的步幅更新网络参数。

梯度爆炸:在递归神经网络中,误差梯度会在更新中累积得到一个非常大的梯度,这样的梯度会大幅更新网络参数,导致网络的不稳定,在极端情况下,权值会变得非常的大以至于结果会溢出(NaN值、无穷或非数值),当梯度爆炸发生时,网络层之间反复乘以大于1.0的值使得梯度值成倍增长

梯度更新:如果误差梯度在更新中累积得到一个非常小的梯度,这也就意味着权值无法更新,最终导致训练失败

利用公式分析原因

经典 RNN 的结构如下图所求:

关于向前传播

假设我们的时间序列只有三段, S 0 S_0 S0 为定值,神经元没有激活函数(便于分析)就可获得各个时间段的状态和输出:
t = 1   时 刻 S 1 = U X 1 + W S 0 + b 1 O 1 = V S 1 + b 2 \begin{aligned}&t = 1 \text{ }时刻\\&S_1 = UX_1 + WS_0 + b_1\\&O_1 = VS_1 + b_2\end{aligned}\\ t=1 S1=UX1+WS0+b1O1=VS1+b2

t = 2   时 刻 S 2 = U X 2 + W S 1 + b 1 O 2 = V S 2 + b 2 \begin{aligned}&t = 2 \text{ }时刻\\&S_2 = UX_2 + WS_1 + b_1\\&O_2 = VS_2 + b_2\end{aligned}\\ t=2 S2=UX2+WS1+b1O2=VS2+b2

t = 3   时 刻 S 3 = U X 3 + W S 2 + b 1 O 3 = V S 3 + b 2 \begin{aligned}&t = 3 \text{ }时刻\\&S_3 = UX_3 + WS_2 + b_1\\&O_3 = VS_3 + b_2\end{aligned}\\ t=3 S3=UX3+WS2+b1O3=VS3+b2

损失函数采用交叉熵 L t = − O t ‾ l o g O t L_t=-\overline{O_t}logO_t Lt=OtlogOt ( O t O_t Ot是 t 时刻的预测输出, O t ‾ \overline{O_t} Ot是 t 时刻的真实输出),那么对于一次训练任务中,损失函数为:
L = ∑ i = 1 T − O t ‾ l o g O t L = \sum_{i=1}^{T}-\overline{O_t}logO_t L=i=1TOtlogOt
T 是序列总长度,上述公式为每一时刻损失值的累加

关于反射传播

我们只对 t3 时时刻的 U、V、W 求偏导,由链式法则可得:
∂ L 3 ∂ V = ∂ L 3 ∂ O 3 ∂ O 3 ∂ V ∂ L 3 ∂ W = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 2 ∂ S 2 ∂ W + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W ∂ L 3 ∂ U = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ U + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 2 ∂ S 2 ∂ U + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ U \begin{aligned}&\frac{\partial{L_3}}{\partial{V}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{V}}\\&\frac{\partial{L_3}}{\partial{W}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{W}}\\&\frac{\partial{L_3}}{\partial{U}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{U}}\end{aligned} VL3=O3L3VO3WL3=O3L3S3O3WS3+O3L3S2O3WS2+O3L3S3O3S2S3S1S2WS1UL3=O3L3S3O3US3+O3L3S2O3US2+O3L3S3O3S2S3S1S2US1
可以简写成:
∂ L 3 ∂ U = ∑ k = 0 3 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S k ∂ S k ∂ U = ∑ k = 0 3 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( ∏ j = k − 1 3 ∂ S j ∂ S j − 1 ) ∂ S k ∂ U 任 意 时 刻 对 参 数 W 求 偏 导 的 公 式 : ∂ L 3 ∂ W = ∑ k = 0 t ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ( ∏ j = k − 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ w \begin{aligned}&\frac{\partial{L_3}}{\partial{U}} = \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_k}}\frac{\partial{S_k}}{\partial{U}}= \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{3}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{U}}\\&任意时刻对参数 W 求偏导的公式:\\&\frac{\partial{L_3}}{\partial{W}} =\sum_{k=0}^{t}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{w}}\\\end{aligned} UL3=k=03O3L3S3O3SkS3USk=k=03O3L3S3O3j=k13Sj1SjUSkWWL3=k=0tO3L3S3O3j=k1tSj1SjwSk

由此可以看出 V 求偏导不存在依赖关系,而 W、U则随时间长度存在着长期的依赖关系,因为 S t S_t St 会随着时间序列向前传播,而同时 S t S_t St 是 U、W 的函数

如果取其中的累乘出来,其中激活函数通常是:tanh = [0, 1] 则:

∏ j = k − 1 t ∂ S j ∂ S j − 1 = ∏ j = k − 1 t t a n h ′ W \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} = \prod_{j=k-1}^{t}tanh^{'}W j=k1tSj1Sj=j=k1ttanhW

  • 由上图可以看出 t a n h ′ ∈ [ 0 , 1 ] tanh^{'}\in [0, 1] tanh[0,1] , 也就是说大部分都是 小于1的数在做累乘,假设 W 也是一个大于0小于1的值时,当 t 很大时, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 ∏ j = k − 1 t t a n h ′ \prod_{j=k-1}^{t}tanh^{'}W 公式中的 \prod_{j=k-1}^{t}tanh^{'} j=k1ttanhWj=k1ttanh 部分会趋向于 0,这就是 RNN 中梯度消失的原因
  • 同理, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 \prod_{j=k-1}^{t}tanh^{'}W 公式中的 j=k1ttanhW W 参数很大时,结果就会趋于无穷,这就是产生 梯度爆炸 的原因
解决办法

面对梯度爆炸的问题,我们可以看到梯度爆炸是因为 W 参数的值过大,而 W 值随着序列长度存在长期的依赖关系,因而我们可以设置一个上限值,一旦超过上限值,就等于我们的预设值,这样就可以解决梯度爆炸的问题了

面对梯度消失的问题,梯度消失的原因是 ∏ j = k − 1 t ∂ S j ∂ S j − 1 \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} j=k1tSj1Sj 求导而产生的,因此想要消除这种情况就需要在求领导的时候去掉就行了,那么怎样去掉呢,一般有二种方法:

  • 使 ∂ S j ∂ S j − 1 ≈ 1 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 1 Sj1Sj1,那么怎样达到这种目标呢?答案是换一种激活函数,我们来看一下 ReLu 作为激活函数的效果:

​ 可以看到 ReLu 导数在定义域大于0的部分是恒等于1,这样就可以解决梯度消失的问题了

  • 使 ∂ S j ∂ S j − 1 ≈ 0 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 0 Sj1Sj0,我们可以采用 LSTM 可以达到这样的效果,那么 LSTM 又是怎样实现的呢,我们在下一篇文章中再来详细解决

参考文献:

[1]. https://www.jiqizhixin.com/articles/2019-01-17-7

声明:

​ 以上内容为个人理解,若有错误,请各位大佬指出,以便大家多作交流!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值