LSTM神经网络

原创 2017年08月03日 21:21:52

LSTM是什么

LSTM即Long Short Memory Network,长短时记忆网络。它其实是属于RNN的一种变种,可以说它是为了克服RNN无法很好处理远距离依赖而提出的。

我们说RNN不能处理距离较远的序列是因为训练时很有可能会出现梯度消失,即通过下面的公式训练时很可能会发生指数缩小,让RNN失去了对较远时刻的感知能力。

EW=tEtW=tk=0Etnettnettst(tj=k+1stsk)skW

解决思路

RNN梯度消失不应该是由我们学习怎么去避免,而应该通过改良让循环神经网络自己具备避免梯度消失的特性,从而让循环神经网络自身具备处理长期序列依赖的能力。

RNN的状态计算公式为St=f(St1,xt),根据链式求导法则会导致梯度变为连乘的形式,而sigmoid小于1会让连乘小得很快。为了解决这个问题,科学家采用了累加的形式,St=tτ=1ΔSτ,其导数也为累加,从而避免梯度消失。LSTM即是使用了累加形式,但它的实现较复杂,下面进行介绍。

LSTM模型

回顾一下RNN的模型,如下图,展开后多个时刻隐层互相连接,而所有循环神经网络都有一个重复的网络模块,RNN的重复网络模块很简单,如下下图,比如只有一个tanh层。
这里写图片描述

这里写图片描述

而LSTM的重复网络模块的结构则复杂很多,它实现了三个门计算,即遗忘门、输入门和输出门。每个门负责是事情不一样,遗忘门负责决定保留多少上一时刻的单元状态到当前时刻的单元状态;输入门负责决定保留多少当前时刻的输入到当前时刻的单元状态;输出门负责决定当前时刻的单元状态有多少输出。

这里写图片描述

每个LSTM包含了三个输入,即上时刻的单元状态、上时刻LSTM的输出和当前时刻输入。

LSTM的机制

这里写图片描述

根据上图咱们一步一步来看LSTM神经网络是怎么运作的。

首先看遗忘门,用来计算哪些信息需要忘记,通过sigmoid处理后为0到1的值,1表示全部保留,0表示全部忘记,于是有

ft=σ(Wf[ht1,xt]+bf)

其中中括号表示两个向量相连合并,Wf是遗忘门的权重矩阵,σ为sigmoid函数,bf为遗忘门的偏置项。设输入层维度为dx,隐藏层维度为dh,上面的状态维度为dc,则Wf的维度为dc×(dh+dx)

这里写图片描述

其次看输入门,输入门用来计算哪些信息保存到状态单元中,分两部分,第一部分为

it=σ(Wi[ht1,xt]+bi)

该部分可以看成当前输入有多少是需要保存到单元状态的。第二部分为

c~t=tanh(Wc[ht1,xt]+bc)

该部分可以看成当前输入产生的新信息来添加到单元状态中。结合这两部分来创建一个新记忆。

这里写图片描述

而当前时刻的单元状态由遗忘门输入和上一时刻状态的积加上输入门两部分的积,即

ct=ftct1+itc~t

这里写图片描述

最后看看输出门,通过sigmoid函数计算需要输出哪些信息,再乘以当前单元状态通过tanh函数的值,得到输出。

ot=σ(Wo[ht1,xt]+bo)

ht=ottanh(ct)

这里写图片描述

LSTM的训练

化繁为简,这里只讨论包含一个LSTM层的三层神经网络(如果有多个层则误差项除了沿时间反向传播外,还会向上一层传播),LSTM向前传播时与三个门相关的公式如下,

ft=σ(Wf[ht1,xt]+bf)

it=σ(Wi[ht1,xt]+bi)

c~t=tanh(Wc[ht1,xt]+bc)

ct=ftct1+itc~t

ot=σ(Wo[ht1,xt]+bo)

ht=ottanh(ct)

需要学习的参数挺多的,同时也可以看到LSTM的输出ht有四个输入分量加权影响,即三个门相关的ftitc~tot,而且其中权重W都是拼接的,所以在学习时需要分割出来,即
Wf=Wfx+Wfh

Wi=Wix+Wih

Wc~=Wc~x+Wc~h

Wo=Wox+Woh

输出层的输入yit=Wyiht,输出为yot=σ(yit)

设某时刻的损失函数为Et=12(ydyot)2,则某样本的损失为

E=Tt=1Et

设当前时刻t的误差项δt=Eht,那么误差沿着时间反向传递则需要计算t-1时刻的误差项δt1,则

δt1=Eht1=Ehththt1=δththt1

LSTM的输出ht可看成是一个复合函数,f[ft(ht1),it(ht1),c~t(ht1),ot(ht1)],由全导数公式有,

htht1=htctctftftnetf,tnetf,tht1+htctctititneti,tneti,tht1+htctctc~tc~tnetc~,tnetc~,tht1+htototneto,tneto,tht1

其中netf,tneti,tnetc~,tneto,t表示对应函数的输入。将上述所有偏导都求出来,

\frac{\partial{{h_t}}}{\partial{{c}_t}}={o}_t \ast (1-\tanh({c}_t)^2) \frac{\partial{{c}_t}}{\partial{{f_{t}}}}={c}_{t-1} \frac{\partial{{f}_t}}{\partial{{net}_{f,t}}}={f}_t \ast (1-{f}_t) \frac{\partial{{net}_{f,t}}}{\partial{{h}_{t-1}}}=W_{fh}

htct=ot(1tanh(ct)2)ctft=ct1ftnetf,t=ft(1ft)netf,tht1=Wfh

同样地,其他也可以求出来,最后得到t时刻和t-1时刻之间的关系。再设
\delta_{f,t}=\frac{\partial{E}}{\partial{{net}_{f,t}}} \delta_{i,t}=\frac{\partial{E}}{\partial{{net}_{i,t}}} \delta_{\tilde{c},t}=\frac{\partial{E}}{\partial{{net}_{\tilde{c},t}}} \delta_{o,t}=\frac{\partial{E}}{\partial{{net}_{o,t}}}

δf,t=Enetf,tδi,t=Eneti,tδc~,t=Enetc~,tδo,t=Eneto,t

得到,
δt1=δf,tWfh+δi,tWih+δc~,tWch+δo,tWoh

接着对某时刻t的所有权重进行求偏导,

EWfh,t=Enetf,tnetf,tWfh,t=δf,tht1

EWih,t=Eneti,tneti,tWih,t=δi,tht1

EWch,t=Enetc~,tnetc~,tWch,t=δc~,tht1

EWoh,t=Eneto,tneto,tWoh,t=δo,tht1

EWfx=Enetf,tnetf,tWfx=δf,txt

EWix=Eneti,tneti,tWix=δi,txt

EWcx=Enetc~,tnetc~,tWcx=δc~,txt

EWox=Eneto,tneto,tWox=δo,txt

Ebo,t=Eneto,tneto,tbo,t=δo,t

Ebf,t=Enetf,tnetf,tbf,t=δf,t

Ebi,t=Eneti,tneti,tbi,t=δi,t

Ebc,t=Enetc~,tnetc~,tbc,t=δc~,t

对于整个样本,它的误差是所有时刻的误差之和,而与上个时刻相关的权重的梯度等于所有时刻的梯度之和,其他权重则不必累加,最终得到

EWfh=j=1tδf,jhj1

EWih=j=1tδi,jhj1

EWch=j=1tδc~,jhj1

EWoh=j=1tδo,jhj1

Ebf=j=1tδf,j

Ebi=j=1tδi,j

Ebc=j=1tδc~,j

Ebo=j=1tδo,j

EWfx=Enetf,tnetf,tWfx=δf,txt

EWix=Eneti,tneti,tWix=δi,txt

EWcx=Enetc~,tnetc~,tWcx=δc~,txt

EWox=Eneto,tneto,tWox=δo,txt

相关阅读:
循环神经网络
卷积神经网络
机器学习之神经网络
机器学习之感知器
神经网络的交叉熵损失函数

========广告时间========

公众号的菜单已分为“分布式”、“机器学习”、“深度学习”、“NLP”、“Java深度”、“Java并发核心”、“JDK源码”、“Tomcat内核”等,可能有一款适合你的胃口。

鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以购买。感谢各位朋友。

为什么写《Tomcat内核设计剖析》

=========================

欢迎关注:

这里写图片描述

版权声明:本文为博主原创文章,未经博主允许不得转载。

RNN以及LSTM的介绍和公式梳理

前言好久没用正儿八经地写博客了,csdn居然也有了markdown的编辑器了,最近花了不少时间看RNN以及LSTM的论文,在组内『夜校』分享过了,再在这里总结一下发出来吧,按照我讲解的思路,理解RNN...

循环神经网络(RNN, Recurrent Neural Networks)介绍

循环神经网络(RNN, Recurrent Neural Networks)介绍   这篇文章很多内容是参考:http://www.wildml.com/2015/09/recurrent-neura...

神经网络的前向传播和误差反向传播(NN,RNN,LSTM)(二)

本文转自:http://blog.csdn.net/u011414416/article/details/46694877  ,

从NN到RNN再到LSTM(1):神经网络NN前馈和误差反向传播

本文将简要介绍神经网络(Neural Network,NN)的相关公式及推导,包括前馈和误差反向传播过程。...

长短记忆型递归神经网络LSTM

原文链接http://www.csdn.net/article/2015-11-25/2826323?ref=myread 摘要:作者早前提到了人们使用RNNs取得的显著成效,基本上...

简单理解LSTM神经网络

递归神经网络 在传统神经网络中,模型不会关注上一时刻的处理会有什么信息可以用于下一时刻,每一次都只会关注当前时刻的处理。举个例子来说,我们想对一部影片中每一刻出现的事件进行分类,如果我们知道电影前面...

人人都能用Python写出LSTM-RNN的代码![你的神经网络学习最佳起步]

我的最佳学习法就是通过玩具代码,一边调试一边学习理论。这篇博客通过一个非常简单的python玩具代码来讲解循环神经网络。 那么依旧是废话少说,放‘码’过来!...
  • zzukun
  • zzukun
  • 2015年11月21日 22:22
  • 42950

LSTM神经网络的详细推导及C++实现

LSTM隐层神经元结构: LSTM隐层神经元详细结构: //让程序自己学会是否需要进位,从而学会加法#include "iostream" #include "math.h" #i...

深入理解LSTM神经网络

LSTM核心思想及详细解读
  • menc15
  • menc15
  • 2017年05月06日 16:35
  • 1496

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:LSTM神经网络
举报原因:
原因补充:

(最多只允许输入30个字)