lstm

lstm_flowj (Fig. 1)

Fig.1 是一张展开的LSTM模型的示意图,绿色的模块表示隐藏层记忆单元,每个记忆单元都有三个输入,两个输出(虽然图中有三个输出箭头,但其中两个输出都是\(h(t)\)),因此,\(t\) 时刻记忆单元的输入、输出分别为 \(x(t), h(t-1), s(t-1)\)\(h(t), s(t)\)

differciate_chain_sh

模型的计算公式如下
\[ \begin{align} g(t) &= \phi(W_{gx}x(t) + W_{gh}h(t-1) + b_g) & (\text{Eq. 1.1}) \\ i(t) &= \sigma(W_{ix}x(t) + W_{ih}h(t-1) + b_i) & (\text{Eq. 1.2}) \\ f(t) &= \sigma(W_{fx}x(t) + W_{fh}h(t-1) + b_f) & (\text{Eq. 1.3}) \\ o(t) &= \sigma(W_{ox}x(t) + W_{oh}h(t-1) + b_o) & (\text{Eq. 1.4}) \\ s(t) &= g(t)*i(t) + s(t-1)*f(t) & (\text{Eq. 1.5}) \\ h(t) &= s(t)*o(t) & (\text{Eq. 1.6}) \end{align} \]
成本函数的定义为:
\[L = \sum_{t=1}^Tl(t)=\frac{1}{2}\sum_{t=1}^T||y(t)-h(t)||^2 \qquad (\text{Eq. 2})\]
其中\(y(t), h(t)\) 分别是目标结果和模型输出。

我们的目标是计算
\[\frac{dL}{dw}=\sum_{t=1}^T\frac{dL}{dh(t)}\frac{dh(t)}{dw} \qquad (\text{Eq. 3})\]
其中, \(\frac{dL}{dh(t)}\)表示成本函\(L\)数对变量\(h(t)\)的全微分,可表示成如下形式
\[ \frac{dL}{dh(t)}=\frac{d}{dh(t)}\sum_{\tau=1}^T l(\tau)=\frac{d}{dh(t)}\sum_{\tau=t}^T l(\tau)=\frac{dL(t)}{dh(t)} \]
我们接下来将推导出以下四个全微分公式,进而计算\(\frac{dL}{dw}\)
\[ \bbox[yellow] { \begin{align} \frac{dL(t)}{dh(t)} \qquad (\text{Eq. 4.1}) \\ \\ \frac{dL(t+1)}{ds(t)} \qquad (\text{Eq. 4.2}) \\ \\ \frac{dL(t)}{dh(t-1)} \qquad (\text{Eq. 4.3}) \\ \\ \frac{dL(t)}{ds(t-1)} \qquad (\text{Eq. 4.4}) \end{align} } \]
其中\(h(t), s(t)\) 代表\(t\)时刻记忆单元的输出值,\(h(t-1), s(t-1)\) 则代表\(t\)时刻记忆单元的输入值。


  • 根据 \(L(t) = l(t) + L(t+1)\),可得
    \[ \frac{dL(t)}{dh(t)}=\frac{dl(t)}{dh(t)}+\frac{dL(t+1)}{dh(t)} \qquad (\text{Eq. 5}) \]

  • 成本函数 \(L\)\(s(t)\) 的全微分形式 \(\frac{dL(t)}{ds(t)}\)
    根据 Eq. 1.6, \(s(t)\) 的值会影响 \(h(t)\), 根据 Eq. 1.5 \(s(t)\) 值会影响 \(s(t+1)\),进而影响 \(h(t+1)\),因此\(\frac{dL}{ds(t)}\)可分为两部分计算

\[ \begin{align} \frac{dL(t)}{ds(t)} & =\frac{dL(t)}{dh(t)}\cdot\frac{dh(t)}{ds(t)}+\frac{dL(t+1)}{dh(t+1)}\cdot\frac{dh(t+1)}{ds(t)} \\ \\ & = \frac{dL(t)}{dh(t)}\cdot\frac{dh(t)}{ds(t)}+\frac{dL(t+1)}{ds(t)} \qquad \qquad (\text{Eq. 6}) \end{align} \]

  • \(t=T\) 时,有初始条件
    \[ \begin{align} &\frac{dL(t)}{dh(t)} = \frac{dl(t)}{dh(t)}=h(t) - y(t) \\ \\ &\frac{dL(t+1)}{ds(t)} = 0 \end{align} \]
    进而可以根据 Eq. 6 求得 \(\frac{dL(t)}{ds(t)}\),且根据 Eq. 1.5 有
    \[ \frac{dL(t)}{ds(t-1)}=\frac{dL(t)}{ds(t)}\cdot f(t) \]
    通过 Eq. 1.1 - 1.6,我们可计算出 \(\frac{dh(t)}{dh(t-1)}\),进而求得
    \[ \frac{dL(t)}{dh(t-1)} = \frac{dL(t)}{dh(t)}\cdot \frac{dh(t)}{dh(t-1)} \]
    至此,我们求得 \(t=T\) 时刻 Eq. 4.1 - Eq. 4.4 四个全微分的值。
  • \(t=T- 1\) 时,
    \[ \begin{align} \frac{dL(t)}{dh(t)} &= \frac{dl(t)}{dh(t)} + \frac{dL(t+1)}{dh(t)} \\ \\ & = h(t) - y(t) + \frac{dL(T)}{dh(T-1)} \\ \\ \frac{dL(t+1)}{ds(t)} &= \frac{dL(T)}{ds(T-1)} \end{align} \]
    此时,与 \(t=T\) 时刻的情况完全一样,可依次求出 \(t=T-2, T-3, ... 2, 1\) 时刻的微分方程 Eq. 4.1-Eq. 4.4,从而求出 Eq. 3

cell
lstm
rnn_types

参考:
[1] http://colah.github.io/posts/2015-08-Understanding-LSTMs/
[2] http://nicodjimenez.github.io/2014/08/08/lstm.html
[3] http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference
[4] A Critical Review of Recurrent Neural Networks for Sequence Learning. Zachary C. Lipton John Berkowitz

转载于:https://www.cnblogs.com/huizhu135/p/6060131.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值