LSTM详解
LSTM(Long Short Term Memory),即一种特殊的RNN形式,主要用来解决长期依赖问题,结构图如下:
可以看到它主要由三个门和一个细胞单元组成。这些门其实都是一种非线性变换,主要目的就是决定有多少信息能够通过。
遗忘门
决定上一时刻的单元状态Ct-1中有多少信息保留到当前时刻Ct
公式:
输入门
决定当前输入有多少信息保存到单元状态Ct
公式:
在此处还有一个细胞单元,即当前时刻的新细胞单元
公式:
最终的当前时刻的细胞单元由输入门和遗忘门共同决定,遗忘门负责决定有多少旧信息要保留下来,输入门决定有多少新信息要保存下来。
公式:
输出门:
决定有多少信息可以被输出。
公式:
当前时刻的隐层状态,由输出门和当前细胞状态共同决定。
代码实现
# LSTM
def lstm_cell(embeding_dim,hidden_dim,h_0,c_0,x_0):
# Weights and Bias for input and hidden tensor
# input
Wi = tf.Variable(tf.random_normal(shape=[embeding_dim, hidden_dim], stddev=0.1))
Ui = tf.Variable(tf.random_norma