tensorflow学习之BasicLSTMCell详解

tf.contrib.rnn.BasicLSTMCell

继承自:LayerRNNCell

Aliases:

  1. Class tf.contrib.rnn.BasicLSTMCell
  2. Class tf.nn.rnn_cell.BasicLSTMCell

基础的LSTM循环网络单元,基于http://arxiv.org/abs/1409.2329.实现。将forget_bias(默认值:1)添加到忘记门的偏差(biases)中以便在训练开始时减少以往的比例(scale)。该神经元不允许单元裁剪(cell clipping),投影层,也不使用peep-hole连接,它是一个基本的LSTM神经元。想要更高级的模型可以使用:tf.nn.rnn_cell.LSTMCell

__init__(

    num_units,

    forget_bias=1.0,

    state_is_tuple=True,

    activation=None,

    reuse=None,

    name=None,

    dtype=None

)

参数说明:

  • num_units:int类型,LSTM单元中的神经元数量,即输出神经元数量
  • forget_bias:float类型,偏置增加了忘记门。从CudnnLSTM训练的检查点(checkpoin)恢复时,必须手动设置为0.0。
  • state_is_tuple:如果为True,则接受和返回的状态是c_state和m_state的2-tuple;如果为False,则他们沿着列轴连接。后一种即将被弃用。
  • activation:内部状态的激活函数。默认为tanh
  • reuse:布尔类型,描述是否在现有范围中重用变量。如果不为True,并且现有范围已经具有给定变量,则会引发错误。
  • name:String类型,层的名称。具有相同名称的层将共享权重,但为了避免错误,在这种情况下需要reuse=True.
  • dtype:该层默认的数据类型。默认值为None表示使用第一个输入的类型。在call之前build被调用则需要该参数。

源码:

class BasicLSTMCell(LayerRNNCell):

  """Basic LSTM recurrent network cell.



  The implementation is based on: http://arxiv.org/abs/1409.2329.



  We add forget_bias (default: 1) to the biases of the forget gate in order to

  reduce the scale of forgetting in the beginning of the training.



  It does not allow cell clipping, a projection layer, and does not

  use peep-hole connections: it is the basic baseline.



  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}

  that follows.

  """



  def __init__(self, num_units, forget_bias=1.0,

               state_is_tuple=True, activation=None, reuse=None, name=None):

    """初始化基础的 LSTM cell.



    Args:

      num_units:int类型,LSTM单元中的神经元数量,即输出神经元数量

forget_bias:float类型,偏置增加了忘记门。从CudnnLSTM训练的检查点(checkpoin)恢复时,必须手动设置为0.0。

state_is_tuple:如果为True,则接受和返回的状态是c_state和m_state的2-tuple;如果为False,则他们沿着列轴连接。后一种即将被弃用。

activation:内部状态的激活函数。默认为tanh

reuse:布尔类型,描述是否在现有范围中重用变量。如果不为True,并且现有范围已经具有给定变量,则会引发错误。

name:String类型,层的名称。具有相同名称的层将共享权重,但为了避免错误,在这种情况下需要reuse=True.



      When restoring from CudnnLSTM-trained checkpoints, must use

      `CudnnCompatibleLSTMCell` instead.

    """

    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name)

    if not state_is_tuple:

      logging.warn("%s: Using a concatenated state is slower and will soon be "

                   "deprecated.  Use state_is_tuple=True.", self)



    # Inputs must be 2-dimensional.

    self.input_spec = base_layer.InputSpec(ndim=2)



    self._num_units = num_units

    self._forget_bias = forget_bias

    self._state_is_tuple = state_is_tuple

    self._activation = activation or math_ops.tanh



  @property

  def state_size(self):

    return (LSTMStateTuple(self._num_units, self._num_units)

            if self._state_is_tuple else 2 * self._num_units)



  @property

  def output_size(self):

    return self._num_units



  def build(self, inputs_shape):

    if inputs_shape[1].value is None:

      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"

                       % inputs_shape)



    input_depth = inputs_shape[1].value

    h_depth = self._num_units

    self._kernel = self.add_variable(

        _WEIGHTS_VARIABLE_NAME,

        shape=[input_depth + h_depth, 4 * self._num_units])

    self._bias = self.add_variable(

        _BIAS_VARIABLE_NAME,

        shape=[4 * self._num_units],

        initializer=init_ops.zeros_initializer(dtype=self.dtype))



    self.built = True



  def call(self, inputs, state):

    """Long short-term memory cell (LSTM).



    Args:

      inputs: `2-D` tensor with shape `[batch_size, input_size]`.

      state: An `LSTMStateTuple` of state tensors, each shaped

        `[batch_size, self.state_size]`, if `state_is_tuple` has been set to

        `True`.  Otherwise, a `Tensor` shaped

        `[batch_size, 2 * self.state_size]`.



    Returns:

      A pair containing the new hidden state, and the new state (either a

        `LSTMStateTuple` or a concatenated state, depending on

        `state_is_tuple`).

    """

    sigmoid = math_ops.sigmoid

    one = constant_op.constant(1, dtype=dtypes.int32)

    # Parameters of gates are concatenated into one multiply for efficiency.

    if self._state_is_tuple:

      c, h = state

    else:

      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)



    gate_inputs = math_ops.matmul(

        array_ops.concat([inputs, h], 1), self._kernel)

    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)



    # i = input_gate, j = new_input, f = forget_gate, o = output_gate

    i, j, f, o = array_ops.split(

        value=gate_inputs, num_or_size_splits=4, axis=one)



    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)

    # Note that using `add` and `multiply` instead of `+` and `*` gives a

    # performance improvement. So using those at the cost of readability.

    add = math_ops.add

    multiply = math_ops.multiply

    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),

                multiply(sigmoid(i), self._activation(j)))

    new_h = multiply(self._activation(new_c), sigmoid(o))



    if self._state_is_tuple:

      new_state = LSTMStateTuple(new_c, new_h)

    else:

      new_state = array_ops.concat([new_c, new_h], 1)

    return new_h, new_state

实现了下面的操作:

用公式表示就是:

从图片和公式可知,LSTM单元有单个输入(Ct-1,ht-1,xt),三个输出(Ct,ht,ht)。

构造函数init有个state_is_tuple=True,如果为True,则接受和返回的状态是c_state和m_state的2-tuple;如果为False,则他们沿着列轴连接。

if self._state_is_tuple:
  new_state = LSTMStateTuple(new_c, new_h)
else:
  new_state = array_ops.concat([new_c, new_h], 1)

LSTM单元的隐藏状态是(Ct,ht)元组。

再来看call函数,下面这一行代码就是计算遗忘门、输入门、和输出门(未经过激活函数)。

# i = input_gate, j = new_input, f = forget_gate, o = output_gate

    i, j, f, o = array_ops.split(

        value=gate_inputs, num_or_size_splits=4, axis=one)

计算输出Ct和ht:

new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),

                multiply(sigmoid(i), self._activation(j)))

    new_h = multiply(self._activation(new_c), sigmoid(o))


代码实例:

import tensorflow as tf



output_dim=128

lstm=tf.nn.rnn_cell.BasicLSTMCell(output_dim)

batch_size=10 #批处理大小

timesteps=40 #时间步长

embedding_dim=300 #词向量维度

inputs=tf.Variable(tf.random_normal([batch_size,embedding_dim]))

previous_state = (tf.random_normal(shape=(batch_size, output_dim)), tf.random_normal(shape=(batch_size, output_dim)))

output,(new_h, new_state)=lstm(inputs,previous_state)



print(output.shape) #(10, 128)

print(new_h.shape) #(10, 128)

print(new_state.shape) #(10, 128)

 

  • 13
    点赞
  • 29
    收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:书香水墨 设计师:CSDN官方博客 返回首页
评论 1

打赏作者

大雄没有叮当猫

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值