源码分析
class BasicLSTMCell(RNNCell):
    """LSTM cell whose four gate pre-activations come from one linear map.

    Args:
      num_units: width of the cell state ``c`` and of the output ``h``.
      forget_bias: constant added to the forget-gate pre-activation before
        its sigmoid.
      state_is_tuple: if True the state is ``LSTMStateTuple(c, h)``;
        otherwise ``c`` and ``h`` are concatenated along axis 1.
      activation: applied to the candidate input and to the new cell state;
        any falsy value falls back to ``math_ops.tanh``.
      reuse: forwarded to the ``RNNCell`` base class as ``_reuse``.
    """

    def __init__(self, num_units, forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None):
        super(BasicLSTMCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh
        # Created lazily on the first `call`, from the actual [inputs, h].
        self._linear = None

    @property
    def state_size(self):
        """State size: an ``LSTMStateTuple`` of two sizes, or their sum."""
        if not self._state_is_tuple:
            return 2 * self._num_units
        return LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        """Width of the output ``h`` (equal to ``num_units``)."""
        return self._num_units

    def call(self, inputs, state):
        """Run one LSTM step.

        Args:
          inputs: input tensor for this step.
          state: previous state, either a ``(c, h)`` pair or ``c`` and ``h``
            concatenated along axis 1, depending on ``state_is_tuple``.

        Returns:
          ``(new_h, new_state)``; ``new_state`` uses the same layout as
          ``state``.
        """
        c, h = (state if self._state_is_tuple
                else array_ops.split(value=state, num_or_size_splits=2, axis=1))
        if self._linear is None:
            # One matmul yields all four gates' pre-activations side by side;
            # the trailing True is presumably build_bias (confirm in _Linear).
            self._linear = _Linear([inputs, h], 4 * self._num_units, True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
        sigmoid = math_ops.sigmoid
        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)
        if not self._state_is_tuple:
            return new_h, array_ops.concat([new_c, new_h], 1)
        return new_h, LSTMStateTuple(new_c, new_h)
实现了下面的操作(Source):
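用公式表示就是（$\sigma$ 为 sigmoid，$\odot$ 为逐元素乘，$[x_t, h_{t-1}]$ 为按列拼接，激活函数默认为 $\tanh$）：

$$
\begin{aligned}
i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f + \text{forget\_bias}) \\
\tilde{C}_t &= \tanh(W_C [x_t, h_{t-1}] + b_C) \\
o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$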
从图片和公式可以看到，LSTM Unit 有三个输入（$C_{t-1}$、$h_{t-1}$、$x_t$）和三个输出（$C_t$、$h_t$、$h_t$）。其中 $h_t$ 出现两次：一次作为本步的输出，一次作为新状态的一部分。
从源码来看，`__init__` 函数有个默认选项 `state_is_tuple=True`。与 BasicRNNCell 不同，LSTM Unit 的隐藏状态是 $(C_t, h_t)$ 元组（`LSTMStateTuple`）；若设为 `False`，则 $C_t$ 与 $h_t$ 会沿第 1 维拼接成一个张量。
再来看 `call` 函数。下面这一行代码通过一次线性变换，同时算出输入门 $i_t$、候选状态 $\tilde{C}_t$、遗忘门 $f_t$ 和输出门 $o_t$ 的线性部分（尚未经过激活函数），分别对应拆分出的 `i`、`j`、`f`、`o`。
# Excerpt from BasicLSTMCell.call: one linear map over [inputs, h] gives all
# four gate pre-activations, split along axis 1 into i (input gate),
# j (candidate input), f (forget gate) and o (output gate).
i, j, f, o = array_ops.split(value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
计算输出 $C_t$ 和 $h_t$：
# Excerpt from BasicLSTMCell.call. New cell state: the forget-gated old state
# plus the input-gated candidate; forget_bias is added to f before its sigmoid.
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
# New output: activation of the new cell state, scaled by the output gate.
new_h = self._activation(new_c) * sigmoid(o)
简单调用例子
import tensorflow as tf
import numpy as np
# Minimal TF 1.x example: run one BasicLSTMCell step on a random initial state.
batch_size = 3
input_dim = 2
output_dim = 4
inputs = tf.placeholder(dtype=tf.float32, shape=(batch_size, input_dim))
# tf.random_normal draws fresh samples every time it is evaluated. The initial
# state must therefore be fetched in the SAME sess.run as the outputs that
# consume it; a later .eval() would print different numbers than those used.
previous_state = (tf.random_normal(shape=(batch_size, output_dim)),
                  tf.random_normal(shape=(batch_size, output_dim)))
cell = tf.contrib.rnn.BasicLSTMCell(num_units=output_dim)
output, (state_c, state_h) = cell(inputs, previous_state)
# Match the placeholder's float32 dtype explicitly.
X = np.ones(shape=(batch_size, input_dim), dtype=np.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Within one run, each random op executes once, so prev_c / prev_h are
    # exactly the state values the cell consumed.
    prev_c, prev_h, o, s_c, s_h = sess.run(
        [previous_state[0], previous_state[1], output, state_c, state_h],
        feed_dict={inputs: X})
    print(X)
    print(prev_c)
    print(prev_h)
    print(o)
    print(s_c)
    print(s_h)
Out（依次为 X、previous_state 的 c 与 h、output、state_c、state_h；结果含随机数，每次运行都会不同。可以看到 output 与 state_h 完全相同）:
[[ 1. 1.]
[ 1. 1.]
[ 1. 1.]]
[[ 0.7887547 0.03365576 0.99595201 -0.30815718]
[ 1.35824049 -0.14050144 -0.23883092 1.9371779 ]
[ 1.82866812 -0.47094828 -0.62541378 -0.05320958]]
[[-1.66138351 0.94374973 1.26592875 0.36519009]
[ 1.34611154 -0.76777643 -0.2827355 0.2608321 ]
[ 0.60226637 0.73776215 0.48991165 -1.5606277 ]]
[[-0.16883671 -0.38824198 0.11436205 -0.06198996]
[-0.17063858 0.30920011 0.04720744 -0.66977751]
[ 0.22892874 0.13076623 -0.46329597 -0.03104772]]
[[-0.56581497 -1.39438677 0.21151143 -0.13372165]
[-1.01093698 0.60569942 0.18325643 -1.23710155]
[ 0.5200395 0.96590245 -0.75785178 -0.28988069]]
[[-0.16883671 -0.38824198 0.11436205 -0.06198996]
[-0.17063858 0.30920011 0.04720744 -0.66977751]
[ 0.22892874 0.13076623 -0.46329597 -0.03104772]]
额外解释
LSTMStateTuple 就是一个简单的命名元组（namedtuple），两个字段依次为 c 和 h，另外提供一个 dtype 属性，用于检查两者的 dtype 是否一致：
# Plain namedtuple base supplying the `c` (cell state) and `h` (output) fields.
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
    """Immutable ``(c, h)`` pair holding an LSTM cell's state.

    Elements are stored in the order ``c``, ``h``. The ``dtype`` property
    requires both elements to share a dtype.
    """

    # Keep instances free of a per-instance __dict__, like the base tuple.
    __slots__ = ()

    @property
    def dtype(self):
        """Return the dtype shared by ``c`` and ``h``.

        Raises:
          TypeError: if ``c.dtype`` and ``h.dtype`` differ.
        """
        c_dtype = self.c.dtype
        h_dtype = self.h.dtype
        if c_dtype != h_dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" % (str(c_dtype), str(h_dtype)))
        return c_dtype