深度学习基础技术分析6:LSTM(含代码分析)

1. 模型图示

LSTM 模型如 图1 所示。横向穿过 cell 上部的线分别称作 c \mathbf{c} c 总线,下部的线称为 h \mathbf{h} h 总线,这意味着 c t − 1 \mathbf{c}_{t - 1} ct1 h t − 1 \mathbf{h}_{t - 1} ht1 会对 t t t 时刻的计算产生影响 。其中:

  1. x t x_t xt 与下

在这里插入图片描述图1. LSTM 模型

2. 相关技术

LSTM 从名称来看,是用于处理长短时序。

3. 代码分析

程序代码见: https://github.com/garstka/char-rnn-java
为了学习它, 我又来逐个方法来分析.

// 前向传播核心代码
// acts 根据字符串存取实型二维数组
public void active(int t, Map<String, DoubleMatrix> acts) {
    // 获取 t 时刻输入
    DoubleMatrix x = acts.get("x" + t);
    // 上一时刻的 h 和 c
    DoubleMatrix preH = null, preC = null;
    if (t == 0) {
        preH = new DoubleMatrix(1, getOutSize());
        preC = preH.dup();
    } else {
        preH = acts.get("h" + (t - 1));
        preC = acts.get("c" + (t - 1));
    }
    
    DoubleMatrix i = Activer.logistic(x.mmul(Wxi).add(preH.mmul(Whi)).add(preC.mmul(Wci)).add(bi));
    DoubleMatrix f = Activer.logistic(x.mmul(Wxf).add(preH.mmul(Whf)).add(preC.mmul(Wcf)).add(bf));
    DoubleMatrix gc = Activer.tanh(x.mmul(Wxc).add(preH.mmul(Whc)).add(bc));
    DoubleMatrix c = f.mul(preC).add(i.mul(gc));
    DoubleMatrix o = Activer.logistic(x.mmul(Wxo).add(preH.mmul(Who)).add(c.mmul(Wco)).add(bo));
    DoubleMatrix gh = Activer.tanh(c);
    DoubleMatrix h = o.mul(gh);
    
    // 存储各个二维矩阵
    acts.put("i" + t, i);
    acts.put("f" + t, f);
    acts.put("gc" + t, gc);
    acts.put("c" + t, c);
    acts.put("o" + t, o);
    acts.put("gh" + t, gh);
    acts.put("h" + t, h);
}

在我运行的程序中, x t x_t xt 为 one-hot 编码的 1 × 62 1 \times 62 1×62 向量, i t i_t it h t h_t ht 均为 1 × 100 1 \times 100 1×100 向量.

代码所表示的信息比 图1 更丰富。矩阵变量之间要运算,很多时候要乘以权重矩阵。为使得结构更清晰,图1 牺牲了表达的准确性。以下将向量的计算翻译成数学表达式,这些向量都会被存储在模型中。

  1. 向量 i t \mathbf{i}_t it 表示 t t t 时刻输入:
    i t = σ ( W x i ⋅ x t + W h i ⋅ h t − 1 + W c i ⋅ c t − 1 + b i ) (1) \mathbf{i}_t = \sigma(\mathbf{W}^{xi} \cdot \mathbf{x}_t + \mathbf{W}^{hi} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{ci} \cdot \mathbf{c}_{t - 1} + bi) \tag{1} it=σ(Wxixt+Whiht1+Wcict1+bi)(1)
  2. 向量 f t \mathbf{f}_t ft 表示遗忘:
    i t = σ ( W x f ⋅ x t + W h f ⋅ h t − 1 + W c f ⋅ c t − 1 + b f ) (2) \mathbf{i}_t = \sigma(\mathbf{W}^{xf} \cdot \mathbf{x}_t + \mathbf{W}^{hf} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{cf} \cdot \mathbf{c}_{t - 1} + bf) \tag{2} it=σ(Wxfxt+Whfht1+Wcfct1+bf)(2)
  3. 向量 g c t \mathbf{gc}_t gct 表示
    g c t = t a n h ( W x c ⋅ x t + W h c ⋅ h t − 1 + b c ) (3) \mathbf{gc}_t = tanh(\mathbf{W}^{xc} \cdot \mathbf{x}_t + \mathbf{W}^{hc} \cdot \mathbf{h}_{t - 1} + bc) \tag{3} gct=tanh(Wxcxt+Whcht1+bc)(3)
  4. 向量 c t \mathbf{c}_t ct 表示
    c t = tanh ⁡ ( f ⊙ c t − 1 + i t ⊙ g c t ) (4) \mathbf{c}_t = \tanh(\mathbf{f} \odot \mathbf{c}_{t - 1} + \mathbf{i}_{t} \odot \mathbf{gc}_t) \tag{4} ct=tanh(fct1+itgct)(4)
  5. 向量 o t \mathbf{o}_t ot 表示
    o t = σ ( W x o ⋅ x t + W h o ⋅ h t − 1 + W c o ⋅ c t + b o ) (5) \mathbf{o}_t = \sigma(\mathbf{W}^{xo} \cdot \mathbf{x}_t + \mathbf{W}^{ho} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{co} \cdot \mathbf{c}_t + bo) \tag{5} ot=σ(Wxoxt+Whoht1+Wcoct+bo)(5)
  6. 向量 g h t \mathbf{gh}_t ght 表示
    g h t = tanh ⁡ ( c t ) (6) \mathbf{gh}_t = \tanh(\mathbf{c}_t) \tag{6} ght=tanh(ct)(6)
  7. 向量 h t \mathbf{h}_t ht 表示本时刻输出.
    h t = o t ⊙ g h t (7) \mathbf{h}_t = \mathbf{o}_t \odot \mathbf{gh}_t \tag{7} ht=otght(7)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值