1. 模型图示
LSTM 模型如 图1 所示。横向穿过 cell 上部的线分别称作 c \mathbf{c} c 总线,下部的线称为 h \mathbf{h} h 总线,这意味着 c t − 1 \mathbf{c}_{t - 1} ct−1 与 h t − 1 \mathbf{h}_{t - 1} ht−1 会对 t t t 时刻的计算产生影响 。其中:
- 从 x t x_t xt 与下
图1. LSTM 模型
2. 相关技术
LSTM 从名称来看,是用于处理长短时序。
3. 代码分析
程序代码见: https://github.com/garstka/char-rnn-java
为了学习它, 我又来逐个方法来分析.
// 前向传播核心代码
// acts 根据字符串存取实型二维数组
public void active(int t, Map<String, DoubleMatrix> acts) {
// 获取 t 时刻输入
DoubleMatrix x = acts.get("x" + t);
// 上一时刻的 h 和 c
DoubleMatrix preH = null, preC = null;
if (t == 0) {
preH = new DoubleMatrix(1, getOutSize());
preC = preH.dup();
} else {
preH = acts.get("h" + (t - 1));
preC = acts.get("c" + (t - 1));
}
DoubleMatrix i = Activer.logistic(x.mmul(Wxi).add(preH.mmul(Whi)).add(preC.mmul(Wci)).add(bi));
DoubleMatrix f = Activer.logistic(x.mmul(Wxf).add(preH.mmul(Whf)).add(preC.mmul(Wcf)).add(bf));
DoubleMatrix gc = Activer.tanh(x.mmul(Wxc).add(preH.mmul(Whc)).add(bc));
DoubleMatrix c = f.mul(preC).add(i.mul(gc));
DoubleMatrix o = Activer.logistic(x.mmul(Wxo).add(preH.mmul(Who)).add(c.mmul(Wco)).add(bo));
DoubleMatrix gh = Activer.tanh(c);
DoubleMatrix h = o.mul(gh);
// 存储各个二维矩阵
acts.put("i" + t, i);
acts.put("f" + t, f);
acts.put("gc" + t, gc);
acts.put("c" + t, c);
acts.put("o" + t, o);
acts.put("gh" + t, gh);
acts.put("h" + t, h);
}
在我运行的程序中, x t x_t xt 为 one-hot 编码的 1 × 62 1 \times 62 1×62 向量, i t i_t it 至 h t h_t ht 均为 1 × 100 1 \times 100 1×100 向量.
代码所表示的信息比 图1 更丰富。矩阵变量之间要运算,很多时候要乘以权重矩阵。为使得结构更清晰,图1 牺牲了表达的准确性。以下将向量的计算翻译成数学表达式,这些向量都会被存储在模型中。
- 向量
i
t
\mathbf{i}_t
it 表示
t
t
t 时刻输入:
i t = σ ( W x i ⋅ x t + W h i ⋅ h t − 1 + W c i ⋅ c t − 1 + b i ) (1) \mathbf{i}_t = \sigma(\mathbf{W}^{xi} \cdot \mathbf{x}_t + \mathbf{W}^{hi} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{ci} \cdot \mathbf{c}_{t - 1} + bi) \tag{1} it=σ(Wxi⋅xt+Whi⋅ht−1+Wci⋅ct−1+bi)(1) - 向量
f
t
\mathbf{f}_t
ft 表示遗忘:
i t = σ ( W x f ⋅ x t + W h f ⋅ h t − 1 + W c f ⋅ c t − 1 + b f ) (2) \mathbf{i}_t = \sigma(\mathbf{W}^{xf} \cdot \mathbf{x}_t + \mathbf{W}^{hf} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{cf} \cdot \mathbf{c}_{t - 1} + bf) \tag{2} it=σ(Wxf⋅xt+Whf⋅ht−1+Wcf⋅ct−1+bf)(2) - 向量
g
c
t
\mathbf{gc}_t
gct 表示
g c t = t a n h ( W x c ⋅ x t + W h c ⋅ h t − 1 + b c ) (3) \mathbf{gc}_t = tanh(\mathbf{W}^{xc} \cdot \mathbf{x}_t + \mathbf{W}^{hc} \cdot \mathbf{h}_{t - 1} + bc) \tag{3} gct=tanh(Wxc⋅xt+Whc⋅ht−1+bc)(3) - 向量
c
t
\mathbf{c}_t
ct 表示
c t = tanh ( f ⊙ c t − 1 + i t ⊙ g c t ) (4) \mathbf{c}_t = \tanh(\mathbf{f} \odot \mathbf{c}_{t - 1} + \mathbf{i}_{t} \odot \mathbf{gc}_t) \tag{4} ct=tanh(f⊙ct−1+it⊙gct)(4) - 向量
o
t
\mathbf{o}_t
ot 表示
o t = σ ( W x o ⋅ x t + W h o ⋅ h t − 1 + W c o ⋅ c t + b o ) (5) \mathbf{o}_t = \sigma(\mathbf{W}^{xo} \cdot \mathbf{x}_t + \mathbf{W}^{ho} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{co} \cdot \mathbf{c}_t + bo) \tag{5} ot=σ(Wxo⋅xt+Who⋅ht−1+Wco⋅ct+bo)(5) - 向量
g
h
t
\mathbf{gh}_t
ght 表示
g h t = tanh ( c t ) (6) \mathbf{gh}_t = \tanh(\mathbf{c}_t) \tag{6} ght=tanh(ct)(6) - 向量
h
t
\mathbf{h}_t
ht 表示本时刻输出.
h t = o t ⊙ g h t (7) \mathbf{h}_t = \mathbf{o}_t \odot \mathbf{gh}_t \tag{7} ht=ot⊙ght(7)