LSTM
理论基础
长短期记忆网络
这篇博客《如何简单的理解LSTM——其实没有那么复杂》)介绍的很好
英文原文《Understanding LSTM Networks》
图片来自上述博客
在应用的时候,我们只需要处理外部的三个变量 h t h_t ht, c t c_t ct, x t x_t xt
pytorch使用解析
LSTM相关的有两个已经包装好的类 LSTM
和LSTMCell
区别在于:
- LSTM类的默认输入一系列时间步,然后它你能够自动处理每一层的输出,不需要手写前向传播
- LSTMCell是LSTM的一个子节点,前向传播的输入输出需要我们自己写。
关于LSTM调用需要注意的是:
- 默认的输入时间步在第一维度,即:(time_step, batch_size, features_nums),在创建LSTM类时,可以指定batch_first=True来变成( batch_size, time_step,features_nums)。在指定后,输出也会随之变化为batch_size在第一维度
- 在前向传播时,由于第一步,我们没有上一层的h0, c0, 可以将h0, c0不写,则默认为0,或者初始化为正态分布(效果可能会更好一点)
- LSTM初始化可以指定多层,使用
num_layers
,默认为1 - LSTM调用输出:
output, (hn, cn) = rnn(input, (h0, c0))
- output : 包含每一层的h的值,最后一个就是hn
(seq_len, batch, num_directions * hidden_size)
- hn: 最后一层的。如果是一维的,就在output的最后一个维度
(num_layers * num_directions, batch, hidden_size)
- cn: 最后一层的cn值
- output : 包含每一层的h的值,最后一个就是hn
关于LSTMCell类注意的是:
- 我们需要自己写forward过程,因为它只是一个单元。各个单元的连接处理要自己写
Encoder-Decoder 和Attention机制
Encoder和Decoder都可以是RNN, CNN, LSTM, GRU等。
这个知乎回答讲解的很好知乎回答
注意的是,我们使用LSTM做为decoder,在pytorch中需要使用LSTMCell来自己写过程。