一般来说,全连接层和卷积层已经可以处理大部分的情况了,而RNN的出现,主要是针对两个问题,第一,处理变长的输入,第二,分析序列的顺序信息。虽然目前我们可以通过空间金字塔池化搭配卷积网络实现不定长度序列的处理分析,可是池化操作会丢失输入的顺序信息,所以RNN还是有他的作用的,而且他的结构是如此的简单和巧妙,所以这次我就想先回顾一下RNN,然后详细探讨一下它的长期依赖问题,最后再分析LSTM到底为什么能改善这个问题。
我们可以先回顾一下如何基于RNN的思想一步步推导出公式。RNN的主要思想在于当前时刻的状态和当前时刻的输入以及上一时刻的状态有关,首先我们可以有公式:
S ( t ) = f ( W x t + U S ( t − 1 ) ) S(t) = f(Wx_{t} + US(t-1)) S(t)=f(Wxt+US(t−1))
上式其实就是先对上一时刻的状态和当前时刻的输入进行一个线性组合,再输入到激活函数中,计算出当前时刻的状态S(t),一方面他可以传给下一时刻计算,另一方面他可以用来计算当前时刻的输出:
o u t p u t = g ( V S ( t ) ) output = g(VS(t)) output=g(VS(t))
对输出来说,其实就是先对当前时刻的状态进行线性变换,再输入到激活函数。
RNN最重要的结构,在于:
W x t + U S ( t − 1 ) Wx_{t} + US(t-1) Wxt+US(t−1)
也就是对上一时刻的状态和当前时刻的输入同时进行分析,这里选择了线性变换的方法,我觉得不论使用什么方法问题不大,比如你可以选择把上一时刻的状态和当前时刻的输入相乘、上一时刻的平方和当前时刻的平方相加等等,重点在于要把两者联系上,这样,在反向传播的时候,模型才能同时考虑两者,达到所谓的"综合考虑上一时刻的状态和当前时刻的输入"这个效果。
综上所述,RNN最重要的创新就是提出了要把上一时刻的状态和当前时刻的输入联系在一起,从而计算当前的输出。
还记得一开始提到RNN能处理任意长度的输入,从这一点就看出RNN的模型参数必须是共享的,也就是上面提到的所有W、U等等参数都是一样的,不然你输入更长的序列就需要更多的参数,那还怎么做预测。
简单回顾了一下RNN的结构,接下来就是重点的问题,为什么RNN会存在长期依赖的问题,我看了不少文章,感觉都说得不是很清楚,有人从sigmoid、tanh的角度出发进行分析,但是我就在想,如果这真的是根本的原因,那岂不是改变激活函数就能解决了,所以我在这里尝试通过公式推导,对广义上的RNN进行分析(不指定激活函数)。
首先我们设定状态0:
S 0 ^ = x 0 \hat {S_0} = x_0 S0^=x0
然后我们写出三个时刻的状态:
S 1 ^ = F ( w x 1 + u S 0 ^ ) \hat {S_1} = F(wx_1+u \hat {S_0}) S1^=F(wx1+uS0^)
S 2 ^ = F ( w x 2 + u S 1 ^ ) \hat {S_2} = F(wx_2+u \hat {S_1}) S2^=F(wx2+uS1^)
S 3 ^ = F ( w x 3 + u S 2 ^ ) \hat {S_3} = F(wx_3+u \hat {S_2}) S3^=F(wx3+uS2^)
把公式展开:
S 1 ^ = F ( w x 1 + u x 0 ) \hat {S_1} = F(wx_1+ux_0) S1^=F(wx1+ux0)
S 2 ^ = F ( w x 2 + u F ( w x 1 + u x 0 ) ) \hat {S_2} = F(wx_2+uF(wx_1+ux_0)) S2^=F(wx2+uF(wx1+ux0))
S 3 ^ = F ( w x 3 + u F ( w x 2 + u F ( w x 1 + u x 0 ) ) ) \hat {S_3} = F(wx_3+uF(wx_2+uF(wx_1+ux_0))) S3^=F(wx3+uF(wx2+uF(wx1+ux0)))
为了后续的分析更简单易懂,这里我选择欧氏距离的绝对值作为损失函数:
L O S S i = ∣ S i ^ − S i ∣ LOSS_i = |\hat {S_i} - S_i| LOSSi=∣Si^−Si∣
L O S S i = S i ^ − S i i f S i ^ > S i LOSS_i = \hat {S_i} - S_i \ \ \ \ if \ \ \hat {S_i} > S_i LOSSi=Si^−Si if Si^>Si
L O S S i = − S i ^ + S i i f S i ^ < S i LOSS_i = -\hat {S_i} + S_i \ \ \ \ if \ \ \hat {S_i} < S_i LOSSi=−Si^+Si if Si^<Si
L O S S = ∑ i = 1 n L O S S i LOSS = \sum _{i=1} ^n LOSS_i LOSS=i=1∑nLOSSi
这样,求导的时候实际上就是对S_i求导,只是根据真实值改变正负号,所以我们可以直接分析每个状态的求导,先看看状态1的偏导:
∂ S 1 ^ ∂ w = ∂ F ∂ w \frac{\partial \hat {S_1}}{\partial w} = \frac{\partial F}{\partial w} ∂w∂S1^=∂w∂F
∂ S 1 ^ ∂ w = ∂ F ( w x 1 + u x 0 ) ∂ ( w x 1 + u x 0 ) ∂ ( w x 1 + u x 0 ) ∂ w = ∂ F ( w x 1 + u x 0 ) ∂ ( w x 1 + u x 0 ) x 1 \frac{\partial \hat {S_1}}{\partial w} = \frac{\partial F(wx_1+ux_0)}{\partial (wx_1+ux_0)} \frac{\partial (wx_1+ux_0)}{\partial w} = \frac{\partial F(wx_1+ux_0)}{\partial (wx_1+ux_0)} x_1 ∂w∂S1^=∂(wx1+ux0)∂F(wx1+ux0)∂w∂(wx1+ux0)=∂(wx1+ux0)∂F(wx1+ux0)x1
继续看状态2:
∂ S 2 ^ ∂ w = ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ w \frac{\partial \hat {S_2}}{\partial w} = \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial w} ∂w∂S2^=∂w∂F(wx2+uF(wx1+ux0))
∂ S 2 ^ ∂ w = ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ w \frac{\partial \hat {S_2}}{\partial w} = \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial (wx_2+uF(wx_1+ux_0))} \frac{\partial (wx_2+uF(wx_1+ux_0))}{\partial w } ∂w∂S2^=∂(wx2+uF(wx1+ux0))∂F(wx2+uF(wx1+ux0))∂w∂(wx2+uF(wx1+ux0))
∂ S 2 ^ ∂ w = ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) ( x 2 + u ∂ ( F ( w x 1 + u x 0 ) ) ∂ w ) \frac{\partial \hat {S_2}}{\partial w} = \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial (wx_2+uF(wx_1+ux_0))} (x2+ u \frac{\partial (F(wx_1+ux_0))}{\partial w } ) ∂w∂S2^=∂(wx2+uF(wx1+ux0))∂F(wx2+uF(wx1+ux0))(x2+u∂w∂(F(wx1+ux0)))
∂ S 2 ^ ∂ w = ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) ( x 2 + u ∂ F ( w x 1 + u x 0 ) ∂ ( w x 1 + u x 0 ) x 1 ) \frac{\partial \hat {S_2}}{\partial w} = \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial (wx_2+uF(wx_1+ux_0))} (x2+ u \frac{\partial F(wx_1+ux_0)}{\partial (wx_1+ux_0)} x_1 ) ∂w∂S2^=∂(wx2+uF(wx1+ux0))∂F(wx2+uF(wx1+ux0))(x2+u∂(wx1+ux0)∂F(wx1+ux0)x1)
∂ S 2 ^ ∂ w = ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) x 2 + u ∂ F ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ ( w x 2 + u F ( w x 1 + u x 0 ) ) ∂ F ( w x 1 + u x 0 ) ∂ ( w x 1 + u x 0 ) x 1 \frac{\partial \hat {S_2}}{\partial w} = \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial (wx_2+uF(wx_1+ux_0))} x2 + u \frac{\partial F(wx_2+uF(wx_1+ux_0))}{\partial (wx_2+uF(wx_1+ux_0))} \frac{\partial F(wx_1+ux_0)}{\partial (wx_1+ux_0)} x_1 ∂w∂S2^=∂(wx2+uF(wx1+ux0))∂F(wx2+uF(wx1+ux0))x2+u∂(wx2+uF(wx1+ux0))∂F(wx2+uF(wx1+ux0))∂(wx1+ux0)∂F(wx1+ux0)x1
接下来由于空间不足就不继续写出状态3的偏导了,不过从上式也足够看出原因了,对于状态2中x1,因为链式求导的关系需要乘上2个函数在对应点的导数的乘积,对于状态3,x1则需要乘上3个函数的导数乘积,以此类推,越往后x1的乘积就越多,造成的后果就是,假如导数大于1,那么乘积就会指数增长,造成梯度非常大,也就是所谓的梯度爆炸,如果导数小于1,乘积就会指数减少到0附近,这时候就认为x1对于偏后的状态(比如S100)的反向传播帮助小(不是梯度消失,不一定消失,因为模型的梯度不仅仅只有x1这一项,还有x2、x3等等),这也就是一直说的长期依赖的问题。
从上面的分析就可以看出,长期依赖的问题是RNN的结构造成的,因为他每个状态都包含了上个状态的输出,所以在反向传播的时候就必须进行链式求导,就必然导致了长期依赖的问题,这是没有办法从根本上解决的,但是还是可以改善的,所以接下来就重点分析LSTM为什么能改善RNN的长期依赖问题。
我认为LSTM的核心就是他的长期记忆,所谓长期记忆,就是一个向量(或者矩阵),没有任何的参数,所以它不会参与到模型的反向传播当中,所以下面就从长期依赖(链式求导)的角度出发,看看长期记忆是怎么协助模型分析的。
以前的话,我喜欢从感性的角度进行分析,比如说"分析长期记忆中什么信息需要遗忘,需要追加什么新的记忆",但是我现在倒认为,这些终究是我们自己强加的理解,归根到底,我们要把模型看成一个复杂的公式,从数学的角度、从抽象的角度进行分析。
长期记忆能发挥作用,缓和长期依赖的问题,重点就在于它和RNN之间存在交互,使得模型在反向传播中能够间接对长期记忆进行调整。
上图中的乘法和加法就是RNN对长期记忆之间的交互,模型上一时刻的不同、当前时刻输入的不同,都会导致长期记忆的不同改变,当然,如果只有RNN对长期记忆的改变是不够的,所以LSTM最后也引入了长期记忆对RNN输出的影响:
可以看到,调整之后的长期记忆,又会反过来影响ht-1和xt,从而计算出该时刻的状态ht,并最终基于此状态计算时刻t的输出。
综上所述,LSTM结构虽然看起来很复杂,但是最核心的结构在于长期记忆,它既能影响每个时刻的输出,又可以根据上一时刻的状态和当前时刻的输入进行调整,而且还不受反向传播的影响,不得不承认,真的很妙。
最后作为总结,我想说一下结构之间的联系的重要性,其实看回RNN和LSTM,我都比较强调他们内在结构的联系,RNN通过加法把不同的时刻的状态联系在一起,LSTM也是通过加法乘法把RNN和长期记忆联系在一起,虽然加法乘法只是一种很简单的运算,但是因为模型可以反向传播,学习参数,同时模型的各个模块之间是相互联系的,使得模型可以对整体进行调整。