一文彻底搞懂LSTM

网络上大部分介绍LSTM的博客都是先给出一张图,然后进行分析。我们今天反其道行之,分析LSTM先不看图。看图反而会吸引你太多的注意力,无法很好的参透其中的奥妙。

咱们首先明确一下,LSTM有哪几个部分组成:
a) 三个门:输入门、输出门、遗忘门,作用暂且不表,只需要知道门可以控制信息的过滤;
b) Cell state:中文翻译为细胞态,感觉并不是很准确,索性就不翻译了。Cell state变量存储的是当前时刻t及其前面所有时刻的混合信息,也就是说,在LSTM中,信息的记忆与维护都是通过cell state变量的。所以,cell state非常重要!
c) hidden state:当前时刻的隐状态。刚接触LSTM时,一直搞不懂hidden state与cell state的作用。因为,在一般的RNN范式中,只有hidden_state这个概念,为啥LSTM中又多出了个Cell state?在b)中已经提到,cell state维护的是到目前时刻t为止,所有积累的信息。看完具体公式会发现,LSTM中的hidden_state其实就是cell state的一种过滤之后的信息,更关注当前时间点的输出结果。LSTM的hidden state其实就是当前时刻的output。
d) 输入 x t x_t xt。当前时间点的输入。不过,需要注意的是,在LSTM每一个时间步中,最终输入其实由 x t x_t xt及上一时刻隐状态 h t − 1 h_{t-1} ht1组成。

对上述几个部分有个基本概念后,下面再看具体的公式,应该会清晰很多。首先,咱们来看下三个门的计算原理:
输入门: i t = σ ( W i [ x t , h t − 1 ] + b i ) i_t=\sigma(W_i[x_t,h_{t-1}]+b_i) it=σ(Wi[xt,ht1]+bi)
输出门: o t = σ ( W o [ x t , h t − 1 ] + b o ) o_t=\sigma(W_o[x_t,h_{t-1}]+b_o) ot=σ(Wo[xt,ht1]+bo)
遗忘门: f t = σ ( W f [ x t , h t − 1 ] + b f ) f_t=\sigma(W_f[x_t,h_{t-1}]+b_f) ft=σ(Wf[xt,ht1]+bf)
很明显的看出来,三个门的形式一模一样,输入也是一模一样,都是当前时间点输入 x t x_t xt及前一时刻隐状态 h t − 1 h_{t-1} ht1两者的concat向量。唯一不同的可能就是在具体初始化时,权重被设置的不同。

简单看完门如何设置之后,下面就来看看重中之重:Cell state如何通过门进行更新:
当前时间 t t t的Cell state C t C_t Ct有两部分组成,一是继承自上一时刻的Cell state C t − 1 C_{t-1} Ct1,二是来源于当前时刻 t t t的输入。

在继承上一时刻 C t − 1 C_{t-1} Ct1时,LSTM通过遗忘门 f t f_t ft去控制到底继承多少历史信息,即 C t − p a r t 1 = f t ⨀ C t − 1 C_{t-part1}=f_t\bigodot C_{t-1} Ctpart1=ftCt1

将当前时刻 t t t的输入信息融入进 C t C_t Ct时,LSTM通过输入门 i t i_t it来控制到底有多少输入会被加进 C t C_t Ct中。没有输入门的话,所有的输入都会被加进 C t C_t Ct中,即 t a n h ( W c [ x t , h t − 1 ] + b t ) tanh(W_c[x_t,h_{t-1}]+b_t) tanh(Wc[xt,ht1]+bt)。但是,输入门的作用就是有选择性的而不是一股脑将输入都加进 C t C_t Ct中。所以, C t − p a r t 2 = i t ⨀ t a n h ( W c [ x t , h t − 1 ] + b c ) C_{t-part2}=i_t\bigodot tanh(W_c[x_t,h_{t-1}]+b_c) Ctpart2=ittanh(Wc[xt,ht1]+bc)

最终,当前时刻 t t t C t = C t − p a r t 1 + c t − p a r t 2 = f t ⨀ C t − 1 + i t ⨀ t a n h ( W c [ x t , h t − 1 ] + b c ) C_t=C_{t-part1}+c_{t-part2}=f_t\bigodot C_{t-1} + i_t\bigodot tanh(W_c[x_t,h_{t-1}]+b_c) Ct=Ctpart1+ctpart2=ftCt1+ittanh(Wc[xt,ht1]+bc)。公式看起来很复杂是不是,但是只要将上面组成 C t C_t Ct的两部分分开来分析,这个公式就没那么复杂了。

截止到这里, C t C_t Ct已经完成了当前时刻 t t t的信息更新,那么当前时刻的输出是什么呢?LSTM中的输出是根据 C t C_t Ct,并利用输出门去控制哪些信息作为输出的,即 h t = o t ⨀ t a n h ( C t ) h_t=o_t\bigodot tanh(C_t) ht=ottanh(Ct)。也就是说,LSTM每一时刻输出的基础是 C t C_t Ct,只不过套了一层激活函数,并且利用输出门控制了一下信息流出。
值得注意的是, ⨀ \bigodot 表示element-wise product。所以所有门的维度和cell state变量是一致的
至此,LSTM的基本概念已经介绍完了,下面再来看这张经典的网络结构图,会不会清晰很多。
LSTM原理图

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值