TensorFlow Tutorial: A Brief Look at the BasicLSTMCell Source Code

Source code analysis

class BasicLSTMCell(RNNCell):

    def __init__(self, num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None):
        super(BasicLSTMCell, self).__init__(_reuse=reuse)

        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh
        self._linear = None

    @property
    def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        if self._linear is None:
            self._linear = _Linear([inputs, h], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state

It implements the following operations (Source):

[Figure: LSTM cell structure showing the forget, input, and output gates]

Expressed as equations:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \tag{1}$$
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \tag{2}$$
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \tag{3}$$
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \tag{4}$$

$$C_t = f_t * C_{t-1} + i_t * \tilde{C}_t \tag{5}$$
$$h_t = o_t * \tanh(C_t) \tag{6}$$

As the figure and equations show, an LSTM unit takes three inputs ($C_{t-1}$, $h_{t-1}$, $x_t$) and produces three outputs ($C_t$ and $h_t$ for the next state, plus $h_t$ again as the cell's output).
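
For reference, the same step can be written out directly in NumPy. This is only a minimal sketch of equations (1)-(6); the weight matrices and biases below are hypothetical stand-ins, not the variables TensorFlow actually creates (which are fused into a single kernel, as shown later).

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o):
    # One LSTM step following equations (1)-(6); all W_* / b_* are hypothetical stand-ins.
    concat = np.concatenate([h_prev, x_t], axis=1)   # [h_{t-1}, x_t]
    f_t = sigmoid(concat.dot(W_f) + b_f)             # (1) forget gate
    i_t = sigmoid(concat.dot(W_i) + b_i)             # (2) input gate
    C_tilde = np.tanh(concat.dot(W_C) + b_C)         # (3) candidate cell state
    o_t = sigmoid(concat.dot(W_o) + b_o)             # (4) output gate
    C_t = f_t * c_prev + i_t * C_tilde               # (5) new cell state
    h_t = o_t * np.tanh(C_t)                         # (6) new hidden state
    return h_t, C_t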

Looking at the source, __init__ has a state_is_tuple option that defaults to True. Unlike BasicRNNCell, the LSTM unit's hidden state is the tuple ($C_t$, $h_t$).
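
A quick way to see the difference is to compare state_size for the two settings (a small sketch; the printed values follow directly from the state_size and output_size properties above):

import tensorflow as tf

cell_tuple = tf.contrib.rnn.BasicLSTMCell(num_units=4)                         # state_is_tuple=True by default
cell_concat = tf.contrib.rnn.BasicLSTMCell(num_units=4, state_is_tuple=False)

print(cell_tuple.state_size)    # LSTMStateTuple(c=4, h=4)
print(cell_concat.state_size)   # 8, i.e. 2 * num_units with c and h concatenated
print(cell_tuple.output_size)   # 4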

Now look at the call function. The line below computes the forget gate $f_t$, the input gate $i_t$, the candidate $\tilde{C}_t$, and the output gate $o_t$ (before the activation functions are applied):

i, j, f, o = array_ops.split(value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
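
Internally, self._linear([inputs, h]) is a single fused affine transform: inputs and h are concatenated along the feature axis and multiplied by one weight matrix of shape [input_dim + num_units, 4 * num_units], so a single matmul produces all four gates before they are split apart. A rough equivalent in plain TensorFlow (a sketch only; kernel_sketch and bias_sketch are hypothetical variables, not the ones _Linear actually creates):

import tensorflow as tf

B, D, N = 3, 2, 4                                             # batch size, input dim, num_units
x = tf.placeholder(tf.float32, (B, D))
h = tf.placeholder(tf.float32, (B, N))
kernel = tf.get_variable("kernel_sketch", shape=(D + N, 4 * N))
bias = tf.get_variable("bias_sketch", shape=(4 * N,), initializer=tf.zeros_initializer())

gates = tf.matmul(tf.concat([x, h], axis=1), kernel) + bias   # shape (B, 4N)
i, j, f, o = tf.split(gates, num_or_size_splits=4, axis=1)    # four (B, N) tensors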

The outputs $C_t$ and $h_t$ are then computed:

new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

A simple usage example

import tensorflow as tf
import numpy as np

batch_size = 3
input_dim = 2
output_dim = 4

inputs = tf.placeholder(dtype=tf.float32, shape=(batch_size, input_dim))
previous_state = (tf.random_normal(shape=(batch_size, output_dim)), tf.random_normal(shape=(batch_size, output_dim)))

cell = tf.contrib.rnn.BasicLSTMCell(num_units=output_dim)
output, (state_c, state_h) = cell(inputs, previous_state)

X = np.ones(shape=(batch_size, input_dim))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch the previous state in the same run; calling .eval() separately would
    # re-sample tf.random_normal and print values the cell never actually used.
    o, s_c, s_h, p_c, p_h = sess.run(
        [output, state_c, state_h, previous_state[0], previous_state[1]],
        feed_dict={inputs: X})

    print(X)
    print(p_c)
    print(p_h)

    print(o)
    print(s_c)
    print(s_h)
Out:
[[ 1.  1.]
 [ 1.  1.]
 [ 1.  1.]]
[[ 0.7887547   0.03365576  0.99595201 -0.30815718]
 [ 1.35824049 -0.14050144 -0.23883092  1.9371779 ]
 [ 1.82866812 -0.47094828 -0.62541378 -0.05320958]]
[[-1.66138351  0.94374973  1.26592875  0.36519009]
 [ 1.34611154 -0.76777643 -0.2827355   0.2608321 ]
 [ 0.60226637  0.73776215  0.48991165 -1.5606277 ]]

[[-0.16883671 -0.38824198  0.11436205 -0.06198996]
 [-0.17063858  0.30920011  0.04720744 -0.66977751]
 [ 0.22892874  0.13076623 -0.46329597 -0.03104772]]
[[-0.56581497 -1.39438677  0.21151143 -0.13372165]
 [-1.01093698  0.60569942  0.18325643 -1.23710155]
 [ 0.5200395   0.96590245 -0.75785178 -0.28988069]]
[[-0.16883671 -0.38824198  0.11436205 -0.06198996]
 [-0.17063858  0.30920011  0.04720744 -0.66977751]
 [ 0.22892874  0.13076623 -0.46329597 -0.03104772]]
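
To check that these numbers really follow equations (5) and (6), the cell's weights can be fetched and the step recomputed in NumPy inside the same session. This is a sketch under two assumptions: that cell.variables returns the fused kernel (shape (input_dim + output_dim, 4 * output_dim)) and the bias in that order, and that the gates come out in the order i, j, f, o as in the call code above.

    # Continuing inside the with tf.Session() block above.
    # Assumes cell.variables == [kernel, bias]; p_c and p_h are the previous state
    # values fetched together with the output.
    kernel, bias = sess.run(cell.variables)
    gates = np.concatenate([X, p_h], axis=1).dot(kernel) + bias
    i, j, f, o = np.split(gates, 4, axis=1)
    sig = lambda z: 1.0 / (1.0 + np.exp(-z))
    new_c = p_c * sig(f + 1.0) + sig(i) * np.tanh(j)   # forget_bias defaults to 1.0
    new_h = np.tanh(new_c) * sig(o)
    print(np.allclose(new_c, s_c, atol=1e-5), np.allclose(new_h, s_h, atol=1e-5))  # expect: True True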

Additional notes

LSTMStateTuple is just a simple named tuple:

import collections

_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
    __slots__ = ()

    @property
    def dtype(self):
        (c, h) = self
        if c.dtype != h.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" % (str(c.dtype), str(h.dtype)))
        return c.dtype
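
Because LSTMStateTuple is a namedtuple, the state can be built and accessed by name as well as by index, for example (a small sketch):

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple

state = LSTMStateTuple(c=tf.zeros((3, 4)), h=tf.zeros((3, 4)))
print(state.c is state[0], state.h is state[1])   # True True
print(state.dtype)                                # <dtype: 'float32'>, via the dtype property above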