使神经网络具有记忆力——RNN及LSTM

我们在进行判断决策时,除了会依据当前的情况,也会调动大脑中的记忆,协同分析。记忆分为长期记忆和短期记忆,短期记忆可以认为是对之前较短时间内发生事件的印象,这对于一些日常生活应用非常的有必要。比如以下两个句子中,“我将在9月10日到达南京”和“我将在9月10日离开南京”,两句话虽然都包含“南京”,但是第一句话中“南京”是目的地,第二句话中“南京”是出发地,做出这个判断的依据是“南京”之前的“到达”和“离开”,我们之所以能够做出准确的判断,是因为我们在读完前面的内容后会在脑海中形成一个短期的记忆,以此支持我们的判断。但传统神经网络对此无能为力,因为在每一时刻,神经网络的输入只有当前的数据,而不依赖之前的处理结果,因此传统神经网络对时序数据束手无策,这大大限制了机器学习模型的能力和应用场景。
神经网络因此需要记忆力!Neural network needs memory!

Recurrent Neural Network(循环神经网络,RNN)

循环神经网络(Recurrent Neural Network,RNN)是一类具有记忆能力,能够处理时序数据的神经网络模型的总称。最简单的循环神经网络直接将前一时刻的输出同时作为当前时刻的输入,模型示意图如下:
在这里插入图片描述

其中,右半部分就是一个常规的神经网络,其隐含层数目没有明确规定,可以根据任务具体设定, x x x y y y分别表示网络的输入和输出。循环神经网络最大的特点是:其具有一个记忆模块(图中左边的 a 1 a_1 a1 a 2 a_2 a2),用来存储前一时刻的网络信息(隐含层输出甚至是输出层输出),在下一时刻,记忆模块中的内容连同输入层数据一起作为网络输入。在时间维度上将神经网络展开如下图所示:
在这里插入图片描述

其中, U U U V V V W W W分别是输入层权重,输出层权重和隐含层权重,在每个时刻网络共享同一套参数,所以RNN计算中如下:
o t = σ ( V s t ) s t = σ ( U x t + W s t − 1 ) \begin{aligned} o_{t} =& \sigma(Vs_t)\\ s_t =& \sigma(Ux_t+Ws_{t-1}) \end{aligned} ot=st=σ(Vst)σ(Uxt+Wst1)

可以看到每一时刻 t t t隐含层的输入是当前时刻网络输入 x t x_t xt和前一时刻隐含层输出 s t − 1 s_{t-1} st1的加权和,并且每一时刻网络使用的都是同一个网络结构。现在讨论一个具体情况,网络结构如上模型示意图所示,假设所有的权重值都为1且不考虑偏置,激活函数采用简单的线性函数,输入序列分别是 [ 1 , 1 ] T [1,1]^T [1,1]T [ 1 , 1 ] T [1,1]^T [1,1]T [ 2 , 2 ] T [2,2]^T [2,2]T [ 1 , 1 ] T [1,1]^T [1,1]T [ 2 , 2 ] T [2,2]^T [2,2]T [ 1 , 1 ] T [1,1]^T [1,1]T。前者的输出是 [ 4 , 4 ] T [4,4]^T [4,4]T [ 12 , 12 ] T [12,12]^T [12,12]T [ 32 , 32 ] T [32,32]^T [32,32]T,后者的输出是 [ 4 , 4 ] T [4,4]^T [4,4]T [ 16 , 16 ] T [16,16]^T [16,16]T [ 36 , 36 ] T [36,36]^T [36,36]T。可以看到当输入序列的顺序改变后,输出也将随之变化,这也符合我们的思维,序列数据中,数据之间是存在关联的。但是传统神经网络,将序列数据看作一个孤立的点,没有考虑数据之间的相互影响。

RNN的变形

RNN存在多种形式,根据记忆模块信息的来源可以分为Elman Network和Jordan Network,Elman Network将隐含层的输出作为记忆内容传给下一时刻,而Jordan Network则将输出层的输出作为记忆内容传给下一时刻。如下图所示:
在这里插入图片描述
根据输入输出序列的长度可以将RNN分为一对一,一对多,多对一,多对多(时序对齐),多对多(时序不齐),如下图:
在这里插入图片描述

一对一:退化为传统神经网络;
一对多:图片语义标注(图片 → \to 文字序列);
多对一:文本情感分类(文字序列 → \to 情感值);
多对多(时序不齐):机器翻译(文字序列 → \to 文字序列);
多对多(时序对齐):视频帧分类(视频帧序列 → \to 分类值序列)。
目前的RNN都只依赖当前时刻之前的序列信息,记忆从前向后单向传递。然而,在有些应用中,当前时刻的决策不仅需要考虑前序信息也需要考虑后序信息,例如单词填空,不仅需要根据前文内容,也要考虑后文内容。因此同样需要将当前时刻之后的序列信息输入当前时刻网络,即记忆也要从后向前反向传递。这一类RNN被称为双向RNN(Bidirectional RNN)。
这里写图片描述

可以看到,双向RNN由一个正向RNN和一个反向RNN组成,正向RNN从前向后读入输入序列,反向RNN从后向前读入输入序列,每一时刻的网络输出由正向RNN的输出和反向RNN的输出共同决定(可以是拼接也可以是求和)。

RNN的计算图

多对RNN的计算图如下图所示,在时间维度上展开,其梯度计算过程同样可以使用反向传播算法求解,叫做BPTT(Backpropagation Through Time)。
在这里插入图片描述

输出层权重 V V V在反向传播过程中不会随时间传递,因此只需要将所有 V V V节点的求导结果相加即可,
∂ C ∂ V = ∑ k = 1 T ∂ C k ∂ y k ∂ y k ∂ V \frac{\partial C}{\partial V}=\sum_{k=1}^T\frac{\partial C_k}{\partial y_k}\frac{\partial y_k}{\partial V}\\ VC=k=1TykCkVyk

输入层权重 U U U、隐含层权重 W W W在反向传播过程中都会随时间传递,以输入层权重 U U U在时刻3为例,
∂ C 3 ∂ U = ∂ C 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ U + ∂ C 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ U + ∂ C 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ h 1 ∂ h 1 ∂ U = ∑ k = 1 3 ∂ C 3 ∂ y 3 ∂ y 3 ∂ h 3 ( ∏ j = k + 1 3 ∂ h j ∂ h j − 1 ) ∂ h k ∂ U ∂ C ∂ U = ∑ t = 1 T ( ∑ k = 1 t ∂ C t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ U ) \begin{aligned} \frac{\partial C_3}{\partial U} =& \frac{\partial C_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}\frac{\partial h_3}{\partial U}+\frac{\partial C_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial U}+\frac{\partial C_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}\frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial U}\\ =&\sum_{k=1}^3\frac{\partial C_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}\Big(\prod_{j=k+1}^3\frac{\partial h_j}{\partial h_{j-1}}\Big)\frac{\partial h_k}{\partial U}\\ \frac{\partial C}{\partial U} =& \sum_{t=1}^T\bigg(\sum_{k=1}^t\frac{\partial C_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}\Big(\prod_{j=k+1}^t\frac{\partial h_j}{\partial h_{j-1}}\Big)\frac{\partial h_k}{\partial U}\bigg) \end{aligned} UC3==UC=y3C3h3y3Uh3+y3C3h3y3h2h3Uh2+y3C3h3y3h2h3h1h2Uh1k=13y3C3h3y3(j=k+13hj1hj)Uhkt=1T(k=1tytCthtyt(j=k+1thj1hj)Uhk)

同理,隐含层权重 W W W的偏导数如下:
∂ C ∂ U = ∑ t = 1 T ( ∑ k = 1 t ∂ C t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W ) \frac{\partial C}{\partial U}=\sum_{t=1}^T\bigg(\sum_{k=1}^t\frac{\partial C_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}\Big(\prod_{j=k+1}^t\frac{\partial h_j}{\partial h_{j-1}}\Big)\frac{\partial h_k}{\partial W}\bigg) UC=t=1T(k=1tytCthtyt(j=k+1thj1hj)Whk)

其中 h j = σ ( W h j − 1 + U x t ) h_j=\sigma(Wh_{j-1}+Ux_t) hj=σ(Whj1+Uxt),当前时刻的记忆 h j h_j hj与上一时刻的记忆 h j − 1 h_{j-1} hj1是累乘关系(覆写),这种连乘关系使得记忆衰减的特别快。在反向传播求导时 ∂ h j ∂ h j − 1 = σ ′ W \frac{\partial h_j}{\partial h_{j-1}}=\sigma^{'}W hj1hj=σW σ ′ \sigma^{'} σ的值通常都小于1,多次相乘会越来越小,如果 W W W也是一个小于1的值,则多次相乘其值就会趋近于0,即出现梯度消失问题,因此很难学习到远距离依赖关系。当然也会存在梯度爆炸问题,只不过这种情况通常发生的可能性小很多。

Long Short-Term Memory(LSTM)

由于Vanilla RNN存在梯度消失问题,对长距离依赖关系(Long-Term Dependencies)的建模能力有限。因为Vanilla RNN在设计时上一时刻的记忆和当前时刻的记忆是相乘关系,这种结构上的限制导致较长时间前的记忆,对当前时刻网络状态影响非常小,在反向传播时,那些梯度也因而很难影响到较长时间前的输入,即梯度消失问题。解决这个问题的思路是尝试让网络记住那些非常重要的信息,最naive的想法是让之前的记忆和当前输入信息相加作为当前的记忆(累加)。Long Short-Term Memory(LSTM)就是这么一种网络,其记忆模块称作Cell,通过控制三个门(gate)的状态来指导Cell中记忆的写入、遗忘、输出。其结构如下图所示:
在这里插入图片描述

三个门分别是输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate),输入门用来控制信息进入Cell,遗忘门用来控制Cell中信息的更新,输出门用来控制Cell中信息的输出,每个门的状态由函数 f f f控制。该模块包含4个输入1个输出,其中4个输入的输入值都相同,数据传递过程如下:
i t = σ f ( W i x t + U i h t − 1 + b i ) f t = σ f ( W f x t + U f h t − 1 + b f ) c t = f t ∗ c t − 1 + i t ∗ σ g ( W c x t + U c h t − 1 + b c ) o t = σ f ( W o x t + U o h t − 1 + b o ) h t = o t ∗ σ h ( c t ) \begin{aligned} i_t =& \sigma_f(W_ix_t+U_ih_{t-1}+b_i)\\ f_t =& \sigma_f(W_fx_t+U_fh_{t-1}+b_f)\\ c_t =& f_t*c_{t-1}+i_t*\sigma_g(W_cx_t+U_ch_{t-1}+b_c)\\ o_t =& \sigma_f(W_ox_t+U_oh_{t-1}+b_o)\\ h_t =& o_t*\sigma_h(c_t) \end{aligned} it=ft=ct=ot=ht=σf(Wixt+Uiht1+bi)σf(Wfxt+Ufht1+bf)ftct1+itσg(Wcxt+Ucht1+bc)σf(Woxt+Uoht1+bo)otσh(ct)

上面公式按照数据从下至上传递过程给出,输入都是当前时刻的输入层 x t x_t xt输入和上一时刻Cell单元输出 h t − 1 h_{t-1} ht1 i t i_t it控制输入门状态,当其值为1时,表示记住全部输入信息;当其值为0时,表示不记住输入信息。 f t f_t ft控制遗忘门状态,当其值为1时,表示保留之前全部信息;当其值为0事,表示清空之前信息。所以当前Cell中保存的信息由当前输入信息和保留历史信息共同决定。 0 t 0_t 0t控制输出门的状态,当其值为1时,表示将Cell中信息全部输出;当其值为0时,表示不将Cell中信息输出。
那LSTM是怎么解决Vanilla RNN梯度消失的问题的呢?我们知道RNN梯度消失的罪魁祸首是 ∏ j = k + 1 t ∂ h j ∂ h j − 1 = ∏ j = k + 1 t σ ′ W \prod_{j=k+1}^t\frac{\partial h_j}{\partial h_{j-1}}=\prod_{j=k+1}^t\sigma^{'}W j=k+1thj1hj=j=k+1tσW。LSTM中同样存在记忆由上一时刻传递到下一时刻,因此对其求导 ∂ c j ∂ c j − 1 = f t + ⋯ \frac{\partial c_j}{\partial c_{j-1}}=f_t+\cdots cj1cj=ft+ f t f_t ft就是遗忘门的输出值,当 f t = 1 f_t=1 ft=1时,梯度将很好的传递到上一时刻;当 f t = 0 f_t=0 ft=0时,即之前的记忆已经被全部遗忘,不会影响到当前时刻,那么梯度也不需要传递回去。
上述LSTM是将上一时刻Cell的输出 h h h作为下一时刻的输入,LSTM还存在其他形式。
1)将上一时刻Cell中保存的值 c c c作为下一时刻的输入,因此此时输入是 [ x t , c t − 1 ] [x_t,c_{t-1}] [xt,ct1],计算公式是:
i t = σ f ( W i x t + U i c t − 1 + b i ) f t = σ f ( W f x t + U f c t − 1 + b f ) c t = f t ∗ c t − 1 + i t ∗ σ g ( W c x t + U c c t − 1 + b c ) o t = σ f ( W o x t + U o c t − 1 + b o ) h t = o t ∗ σ h ( c t ) \begin{aligned} i_t =& \sigma_f(W_ix_t+U_ic_{t-1}+b_i)\\ f_t =& \sigma_f(W_fx_t+U_fc_{t-1}+b_f)\\ c_t =& f_t*c_{t-1}+i_t*\sigma_g(W_cx_t+U_cc_{t-1}+b_c)\\ o_t =& \sigma_f(W_ox_t+U_oc_{t-1}+b_o)\\ h_t =& o_t*\sigma_h(c_t) \end{aligned} it=ft=ct=ot=ht=σf(Wixt+Uict1+bi)σf(Wfxt+Ufct1+bf)ftct1+itσg(Wcxt+Ucct1+bc)σf(Woxt+Uoct1+bo)otσh(ct)

2)上一时刻的Cell中内容 c t − 1 c_{t-1} ct1、LSTM输出内容 h t − 1 h_{t-1} ht1都输入下一时刻,即输入为 [ x t , h t − 1 , c t − 1 ] [x_t,h_{t-1},c_{t-1}] [xt,ht1,ct1],计算公式类似,在此不再赘述。
明确了LSTM的内部结构后,我们可以将当作一个新的neural,替换Vanilla RNN中的隐含层神经元。如下图所示:
在这里插入图片描述

注意在LSTM输入 [ x t , h t − 1 ] [x_t,h_{t-1}] [xt,ht1]时采用了双线箭头,代表LSTM需要使用输入多次(4次),因此LSTM相较Vanilla RNN参数使用增加了4倍。再进一步将LSTM内部结构细化,如下图:
在这里插入图片描述

总结

RNN的提出解决了传统神经网络没有记忆能力,无法处理序列数据的缺点,推动了深度学习的快速发展,目前RNN已经成为语言建模、机器翻译、文本处理、图片描述等任务的标配,并且取得了一定的成功。更多更丰富有趣的RNN成功应用请参考The Unreasonable Effectiveness of Recurrent Neural Networks

参考文献

李宏毅主页
RNN, LSTM, GRU, SRU, Multi-Dimensional LSTM, Grid LSTM, Graph LSTM系列解读
LSTM算法原理简介及Tutorial
LSTM如何解决梯度消失问题

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值