TensorFlow Source Code Analysis -- Recurrent Neural Networks (RNN / LSTM / GRU)

Overview

Recurrent neural networks (RNNs) are used for sequence problems such as sequence labeling, and are widely applied in natural language processing, speech recognition, and other fields. LSTM and GRU are currently the two most widely used RNN variants. TensorFlow already wraps these common models quite well, but in practice we often need to modify an LSTM or GRU to fit a particular requirement, or even implement a new RNN model from scratch. This article walks through how the RNN family is implemented in TensorFlow, so that you have a reference when implementing your own RNN model.

Basic Concepts

This article focuses on the source code of TensorFlow's RNN family; below we only briefly recap the relevant formulas, and the underlying theory is left to other resources.

RNN

h_{t}=\sigma (W_{h}\cdot [h_{t-1},x_{t}]+b_{h})

LSTM

Forget gate: f_{t}=\sigma (W_{f}\cdot [h_{t-1},x_{t}]+b_{f})

Input gate: i_{t}=\sigma (W_{i}\cdot [h_{t-1},x_{t}]+b_{i})

Output gate: o_{t}=\sigma (W_{o}\cdot [h_{t-1},x_{t}]+b_{o})

Cell state: \tilde{C}_{t}=\tanh(W_{c}\cdot [h_{t-1},x_{t}]+b_{c})

                C_{t}=f_{t}\cdot C_{t-1}+i_{t}\cdot \tilde{C}_{t}

Output: h_{t}=o_{t}\cdot \tanh(C_{t})

GRU

Reset gate: r_{t}=\sigma (W_{r}\cdot [h_{t-1},x_{t}]+b_{r})

Update gate: z_{t}=\sigma (W_{z}\cdot [h_{t-1},x_{t}]+b_{z})

Candidate state: \tilde{h}_{t}=\tanh(W_{h}\cdot [r_{t}\cdot h_{t-1},x_{t}]+b_{h})

               Output: h_{t}=(1-z_{t})\cdot h_{t-1} + z_{t}\cdot \tilde{h}_{t}
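
To make the notation concrete, here is a minimal numpy sketch of a single LSTM step following the equations above (the function name, shapes, and the row-vector convention x @ W are illustrative, not taken from the TensorFlow source):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_o, b_o, W_c, b_c):
    """One LSTM step; every weight matrix acts on the concatenation [h_{t-1}, x_t]."""
    hx = np.concatenate([h_prev, x_t], axis=-1)
    f_t = sigmoid(hx @ W_f + b_f)        # forget gate
    i_t = sigmoid(hx @ W_i + b_i)        # input gate
    o_t = sigmoid(hx @ W_o + b_o)        # output gate
    c_tilde = np.tanh(hx @ W_c + b_c)    # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde   # new cell state
    h_t = o_t * np.tanh(c_t)             # new output
    return h_t, c_t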

Overall Flow

In TensorFlow, the RNN-related source code falls into two categories. The first is the cell classes that implement the per-step logic; they all inherit from RNNCell and include BasicRNNCell, BasicLSTMCell, GRUCell, and so on. The second is the flow-control code that drives a cell across time steps, including the dynamic unidirectional RNN function tf.nn.dynamic_rnn and the dynamic bidirectional RNN function tf.nn.bidirectional_dynamic_rnn. This article focuses on the simpler tf.nn.dynamic_rnn.

RNNCells

RNNCell

RNNCell is the base class that every cell in the RNN family inherits from. Its source code is as follows:

class RNNCell(base_layer.Layer):

  def __call__(self, inputs, state, scope=None):
    if scope is not None:
      with vs.variable_scope(scope,
                             custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      scope_attrname = "rnncell_scope"
      scope = getattr(self, scope_attrname, None)
      if scope is None:
        scope = vs.variable_scope(vs.get_variable_scope(),
                                  custom_getter=self._rnn_get_variable)
        setattr(self, scope_attrname, scope)
      with scope:
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if context.in_graph_mode():
      trainable = (variable in tf_variables.trainable_variables() or
                   (isinstance(variable, tf_variables.PartitionedVariable) and
                    list(variable)[0] in tf_variables.trainable_variables()))
    else:
      trainable = variable._trainable  # pylint: disable=protected-access
    if trainable and variable not in self._trainable_weights:
      self._trainable_weights.append(variable)
    elif not trainable and variable not in self._non_trainable_weights:
      self._non_trainable_weights.append(variable)
    return variable

  @property
  def state_size(self):
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    raise NotImplementedError("Abstract method")

  def build(self, _):
    pass

  def zero_state(self, batch_size, dtype):
    state_size = self.state_size
    if hasattr(self, "_last_zero_state"):
      (last_state_size, last_batch_size, last_dtype,
       last_output) = getattr(self, "_last_zero_state")
      if (last_batch_size == batch_size and
          last_dtype == dtype and
          last_state_size == state_size):
        return last_output
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      output = _zero_state_tensors(state_size, batch_size, dtype)
    self._last_zero_state = (state_size, batch_size, dtype, output)
    return output

The __call__ method of RNNCell takes inputs and state. inputs is the data fed in at each time step: a 2-D Tensor whose two dimensions are the batch size and the per-example input vector, i.e. inputs has shape [batch_size, input_size]. state is the state carried over from the previous step's cell; its leading dimension is also batch_size, and depending on how state_size is configured it can be a single state vector (state has shape [batch_size, state_size]) or a tuple holding the state vector together with any other custom data (state is then a tuple of tensors, each with leading dimension batch_size).
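
As a small illustration of these shapes (a sketch using two of the concrete cells discussed later; the sizes are arbitrary, and BasicLSTMCell appears here only because its default state is a tuple):

import tensorflow as tf

batch_size, input_size, num_units = 32, 100, 64
x_t = tf.placeholder(tf.float32, [batch_size, input_size])   # one time step of input

rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
rnn_state = rnn_cell.zero_state(batch_size, tf.float32)      # Tensor of shape [batch_size, num_units]

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
lstm_state = lstm_cell.zero_state(batch_size, tf.float32)    # LSTMStateTuple(c=[32, 64], h=[32, 64])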

In RNNCell.__call__, a variable scope is first opened with _rnn_get_variable as its custom getter, so every variable created inside the scope passes through _rnn_get_variable, which classifies it as trainable or non-trainable and records it in the class's _trainable_weights or _non_trainable_weights member. __call__ then delegates to the __call__ of the parent base_layer.Layer class, which ultimately invokes the concrete cell's call method. Every cell that inherits from RNNCell must implement call, whose arguments are again inputs and state and whose body is the cell's actual computation; thanks to this wrapping in RNNCell.__call__, a cell instance can be invoked like a function.

  def __call__(self, inputs, state, scope=None):
    if scope is not None:
      with vs.variable_scope(scope,
                             custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      scope_attrname = "rnncell_scope"
      scope = getattr(self, scope_attrname, None)
      if scope is None:
        scope = vs.variable_scope(vs.get_variable_scope(),
                                  custom_getter=self._rnn_get_variable)
        setattr(self, scope_attrname, scope)
      with scope:
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if context.in_graph_mode():
      trainable = (variable in tf_variables.trainable_variables() or
                   (isinstance(variable, tf_variables.PartitionedVariable) and
                    list(variable)[0] in tf_variables.trainable_variables()))
    else:
      trainable = variable._trainable  # pylint: disable=protected-access
    if trainable and variable not in self._trainable_weights:
      self._trainable_weights.append(variable)
    elif not trainable and variable not in self._non_trainable_weights:
      self._non_trainable_weights.append(variable)
    return variable

The state_size and output_size methods must be overridden by the concrete cell; they define the sizes of the state vector state and the output vector output. Because both carry the @property decorator, other classes can access them as attributes.

  @property
  def state_size(self):
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    raise NotImplementedError("Abstract method")

The build method also needs to be implemented by the concrete cell. build is the method recommended by the base_layer.Layer class; it is invoked from __call__ once the framework knows the shape and dtype of the layer's input, and inside build you should call add_variable() to create the variables that will be learned.

  def build(self, _):
    pass

Finally, RNNCell provides zero_state, which produces an all-zero initial state. If the concrete cell does not define its own state-initialization method, this method is used to generate the initial state value.

  def zero_state(self, batch_size, dtype):
    state_size = self.state_size
    if hasattr(self, "_last_zero_state"):
      (last_state_size, last_batch_size, last_dtype,
       last_output) = getattr(self, "_last_zero_state")
      if (last_batch_size == batch_size and
          last_dtype == dtype and
          last_state_size == state_size):
        return last_output
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      output = _zero_state_tensors(state_size, batch_size, dtype)
    self._last_zero_state = (state_size, batch_size, dtype, output)
    return output

To summarize:

Each concrete cell implementation should inherit from RNNCell and implement __init__, build, call, state_size, and output_size. The __init__/build/call trio is the general recommended pattern for implementing layers in TensorFlow: __init__ stores the configuration; once the shape and dtype of the input are known, build initializes the parameters to be learned; and call contains the actual computation. state_size and output_size are the two methods specific to RNN cells and define the sizes of the state vector state and the output vector output.
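
As a reference, here is a minimal sketch of a custom cell that follows this pattern and implements the plain RNN recurrence from the Basic Concepts section (the class name and variable names are illustrative; the constructor mirrors BasicRNNCell below and assumes a TF 1.x version where that pattern works):

import tensorflow as tf

class MyRNNCell(tf.nn.rnn_cell.RNNCell):
    """Minimal custom cell: h_t = tanh(W . [h_{t-1}, x_t] + b)."""

    def __init__(self, num_units, reuse=None, name=None):
        super(MyRNNCell, self).__init__(_reuse=reuse, name=name)
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def build(self, inputs_shape):
        input_depth = inputs_shape[1].value
        self._kernel = self.add_variable(
            "kernel", shape=[input_depth + self._num_units, self._num_units])
        self._bias = self.add_variable(
            "bias", shape=[self._num_units],
            initializer=tf.zeros_initializer())
        self.built = True

    def call(self, inputs, state):
        output = tf.tanh(
            tf.matmul(tf.concat([inputs, state], 1), self._kernel) + self._bias)
        return output, output

An instance of such a cell can then be driven by tf.nn.dynamic_rnn exactly like the built-in cells described next.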

BasicRNNCell

The BasicRNNCell class in rnn_cell_impl implements the basic RNN.

A basic RNN is just a fully connected network whose forward pass is repeated across time steps; its hidden-layer computation is a single fully connected layer.

The __init__ method stores the number of hidden units num_units and the activation function activation.

Following TensorFlow's general recommended pattern for layers, __init__ also sets the input_spec attribute to a base_layer.InputSpec, which declares properties of the expected input such as the number of dimensions, the size of each dimension, and the dtype. BasicRNNCell declares that its input must be 2-dimensional.

  def __init__(self, num_units, activation=None, reuse=None, name=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse, name=name)
    self.input_spec = base_layer.InputSpec(ndim=2)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

BasicRNNCell's state_size and output_size both return num_units, i.e. the state and the output have the same number of units.

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

In build, the learned weights and bias are created based on the input shape. The inputs_shape passed in is [batch_size, input_size]; input_depth is the size of each input vector. _kernel has shape [input_depth + self._num_units, self._num_units]: the input side corresponds to the concatenation [h_{t-1},x_{t}], and the output side is the value returned by state_size, i.e. num_units.

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

In call, the input and the state from the previous time step are concatenated and multiplied by _kernel, the bias is added, and the activation function is applied; in other words, it is simply one fully connected layer. call returns both the output vector and the state vector for the current time step; in a plain RNN these two are identical.

  def call(self, inputs, state):
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output
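
The same computation, written out in numpy to make the shapes explicit (a sketch; kernel has shape [input_size + num_units, num_units], as created in build above):

import numpy as np

def basic_rnn_step(inputs, state, kernel, bias):
    """numpy equivalent of BasicRNNCell.call with the default tanh activation."""
    gate_inputs = np.concatenate([inputs, state], axis=1) @ kernel + bias
    output = np.tanh(gate_inputs)
    return output, output   # in a plain RNN the output and the new state are identical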

BasicLSTMCell

The BasicLSTMCell class in rnn_cell_impl implements the basic LSTM. rnn_cell_impl also contains LSTMCell, which implements more advanced features such as cell clipping and a projection layer; here we only cover BasicLSTMCell.

Besides the same basic configuration as BasicRNNCell, __init__ adds the forget_bias and state_is_tuple options. If state_is_tuple is True, the state passed into call, as well as the state returned for the current time step, is an LSTMStateTuple holding the current cell state and the current output; in an LSTM these two vectors are different. If state_is_tuple is False, the state returned by call is the cell state and the output concatenated column-wise into a single vector.

  def __init__(self, num_units, forget_bias=1.0,
               state_is_tuple=True, activation=None, reuse=None, name=None):
    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self.input_spec = base_layer.InputSpec(ndim=2)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

As mentioned when discussing the state_is_tuple parameter, in state_size the state size is described as an LSTMStateTuple when state_is_tuple is True; otherwise the cell state and the output are concatenated, so the size is 2 * num_units.

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

In build, note that the kernel's output dimension is 4 * num_units. The reason is that in the LSTM equations the forget gate, input gate, output gate, and candidate state are all computed from the same basis [h_{t-1},x_{t}], so the four matrix multiplications can be fused into one (W_{fioc}\cdot [h_{t-1},x_{t}]+b_{fioc}) and split apart again afterwards.

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

Finally, let's look at BasicLSTMCell's call method. gate_inputs performs the four shared computations fused into a single matmul; array_ops.split then breaks the result into the input gate, candidate state, forget gate, and output gate. The new cell state and output vectors are produced according to the equations, and depending on state_is_tuple the two vectors are either wrapped in an LSTMStateTuple or concatenated before being returned.

  def call(self, inputs, state):
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
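
The reasoning behind the fused 4 * num_units kernel can be checked with a few lines of numpy (a sketch; after the single matmul, each column block of the kernel corresponds to one gate, in the same i, j, f, o order as the split above):

import numpy as np

np.random.seed(0)
batch, input_size, num_units = 2, 3, 4
x = np.random.randn(batch, input_size)
h = np.random.randn(batch, num_units)
W = np.random.randn(input_size + num_units, 4 * num_units)   # fused kernel

fused = np.concatenate([x, h], axis=1) @ W
i, j, f, o = np.split(fused, 4, axis=1)                      # same split as in call

# Each block equals a separate matmul against the corresponding column block of W.
W_i = W[:, :num_units]
assert np.allclose(i, np.concatenate([x, h], axis=1) @ W_i)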

GRUCell

GRUCell's __init__ is similar to the previous cells; the only difference is that the user can also supply a kernel (weight) initializer and a bias initializer here.

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None):
    super(GRUCell, self).__init__(_reuse=reuse, name=name)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

Both GRUCell's state vector state and its output vector output have size num_units.

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

In the GRU equations, the reset gate and the update gate both operate on [h_{t-1},x_{t}], so they can be computed together and _gate_kernel therefore has 2 * num_units output units. The candidate state, however, operates on [r_{t}\cdot h_{t-1},x_{t}], which cannot be fused with the gates, so a separate _candidate_kernel is initialized.

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    self._gate_kernel = self.add_variable(
        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, 2 * self._num_units],
        initializer=self._kernel_initializer)
    self._gate_bias = self.add_variable(
        "gates/%s" % _BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=(
            self._bias_initializer
            if self._bias_initializer is not None
            else init_ops.constant_initializer(1.0, dtype=self.dtype)))
    self._candidate_kernel = self.add_variable(
        "candidate/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units],
        initializer=self._kernel_initializer)
    self._candidate_bias = self.add_variable(
        "candidate/%s" % _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(
            self._bias_initializer
            if self._bias_initializer is not None
            else init_ops.zeros_initializer(dtype=self.dtype)))

    self.built = True

In call, the reset and update gates are computed first, then the candidate state, and finally the output. Note that the code computes new_h = u * state + (1 - u) * c, i.e. the update gate u here weights the previous state h_{t-1}; this is the mirror image of the convention in the formula given earlier and is equivalent up to flipping the gate.

  def call(self, inputs, state):
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._gate_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    candidate = math_ops.matmul(
        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    new_h = u * state + (1 - u) * c
    return new_h, new_h

Flow Control

An RNN is a model for sequence data, and the cells described above only define the per-step computation. We therefore also need flow-control code that starts from the first step of the input, runs the cell at every step, saves its return values and feeds them to the next step, and returns the final results as the output of the whole RNN model. RNNs can be divided into unidirectional and bidirectional by the direction of computation, and into single-layer and multi-layer by depth. In TensorFlow the unidirectional flow is driven by tf.nn.dynamic_rnn and the bidirectional flow by tf.nn.bidirectional_dynamic_rnn; a multi-layer RNN is built by wrapping the ordinary cells in a tensorflow.contrib.rnn.MultiRNNCell and then using dynamic_rnn or bidirectional_dynamic_rnn depending on the direction. This article only covers tf.nn.dynamic_rnn; please refer to the source code for the others.
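
For reference, stacking cells into a multi-layer RNN only requires wrapping them (a sketch using the tf.nn.rnn_cell alias of MultiRNNCell; the layer count and sizes are arbitrary):

import tensorflow as tf

# Two stacked LSTM layers; the combined cell is then passed to dynamic_rnn
# (or bidirectional_dynamic_rnn) just like a single cell.
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(128) for _ in range(2)])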

dynamic_rnn

When building an RNN model, once we have created the RNNCell we need, we pass the cell and the inputs to tf.nn.dynamic_rnn, and the RNN model over that concrete input is complete. tf.nn.dynamic_rnn returns the outputs of every step and the state of the last step, as shown below:

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
 
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
 
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
 
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

Besides cell and inputs, dynamic_rnn accepts a sequence_length parameter that records the real length of each sequence in inputs. inputs is a batch of data, so when preparing it the second dimension must be sized to the longest sequence in the batch and the shorter ones padded at the end; if sequence_length is given, only the real-length portion of each sequence is computed and the padded positions are skipped. initial_state can be used to supply the initial value of the state.

def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
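
A sketch of using sequence_length with a padded batch (the sizes and placeholders are illustrative):

import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(64)
# inputs: [batch_size, max_time, input_size]; shorter sequences are zero-padded to max_time
inputs = tf.placeholder(tf.float32, [None, 20, 100])
# real (unpadded) length of every sequence in the batch
seq_len = tf.placeholder(tf.int32, [None])

outputs, state = tf.nn.dynamic_rnn(cell, inputs,
                                   sequence_length=seq_len,
                                   dtype=tf.float32)
# Beyond each sequence's real length, the outputs are zero and the state is no longer updated.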

Normally the input data prepared by the user has shape [batch_size, time_step, input_size]. For convenience of computation, dynamic_rnn converts it to [time_step, batch_size, input_size]; the _transpose_batch_time method is responsible for this conversion, as shown below:

flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
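
Conceptually, this conversion is just a transpose of the first two dimensions (a sketch):

import tensorflow as tf

x = tf.placeholder(tf.float32, [32, 20, 100])      # [batch_size, time_step, input_size]
x_time_major = tf.transpose(x, perm=[1, 0, 2])     # [time_step, batch_size, input_size]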

If the user supplied an initial value for state, it is used directly; otherwise zero_state from the RNNCell base class is called to initialize state to all zeros.

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      state = cell.zero_state(batch_size, dtype)

After that, _dynamic_rnn_loop is called for the next stage of processing. _dynamic_rnn_loop returns the outputs of every step and the state of the last step, final_state.

(outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

As we can see, dynamic_rnn itself only validates the arguments, reorders the input dimensions as required by the computation, sets the initial value of state, and passes along parameters such as the parallelism degree parallel_iterations and sequence_length; the real work is handed over to _dynamic_rnn_loop.

_dynamic_rnn_loop

_dynamic_rnn_loop mainly creates an output TensorArray and an input TensorArray. A TensorArray in TensorFlow can be thought of as an array of Tensors with a dynamic size. Both TensorArrays have size time_steps. The output TensorArray is created empty and is filled step by step as each step's result is computed; the input TensorArray is filled with the values of the input right after it is created.

  # Create a TensorArray
  def _create_ta(name, element_shape, dtype):
    return tensor_array_ops.TensorArray(dtype=dtype,
                                        size=time_steps,
                                        element_shape=element_shape,
                                        tensor_array_name=base_name + name)

  # Create the output TensorArray and the input TensorArray
  in_graph_mode = context.in_graph_mode()
  if in_graph_mode:
    output_ta = tuple(
        _create_ta(
            "output_%d" % i,
            element_shape=(tensor_shape.TensorShape([const_batch_size])
                           .concatenate(
                               _maybe_tensor_shape_from_tensor(out_size))),
            dtype=_infer_state_dtype(dtype, state))
        for i, out_size in enumerate(flat_output_size))
    input_ta = tuple(
        _create_ta(
            "input_%d" % i,
            element_shape=flat_input_i.shape[1:],
            dtype=flat_input_i.dtype)
        for i, flat_input_i in enumerate(flat_input))
    # Fill the input TensorArray with the input values
    input_ta = tuple(ta.unstack(input_)
                     for ta, input_ in zip(input_ta, flat_input))
  else:
    output_ta = tuple([0 for _ in range(time_steps.numpy())]
                      for i in range(len(flat_output_size)))
    input_ta = flat_input
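
A minimal TensorArray sketch showing unstack (filling the array from a tensor) and per-step read/write, which is how input_ta and output_ta are used here (the values are arbitrary):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])    # [time_steps, input_size]

input_ta = tf.TensorArray(dtype=tf.float32, size=3).unstack(x)
output_ta = tf.TensorArray(dtype=tf.float32, size=3)
for t in range(3):
    output_ta = output_ta.write(t, input_ta.read(t) * 10.0)   # one "step" of fake computation
result = output_ta.stack()    # back to a dense [3, 2] tensor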

Then TensorFlow's loop primitive while_loop is invoked to iterate time_steps times, calling _time_step at each iteration to run the cell and update time, output_ta, and state.

  loop_bound = time_steps
  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      body=_time_step,
      loop_vars=(time, output_ta, state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)
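
A minimal tf.while_loop sketch with the same structure as the call above, i.e. a time counter plus state carried through loop_vars (the body here just accumulates the step index instead of running a cell):

import tensorflow as tf

time_steps = tf.constant(5)

def cond(time, acc):
    return time < time_steps

def body(time, acc):
    # Stand-in for one cell step: update the accumulator and advance time.
    return time + 1, acc + tf.cast(time, tf.float32)

_, final_acc = tf.while_loop(cond=cond, body=body,
                             loop_vars=(tf.constant(0), tf.constant(0.0)))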

_time_step

First, the input for the current time step is read from the input TensorArray.

input_t = tuple(ta.read(time) for ta in input_ta)

A closure that invokes the cell is defined, taking the current time step's input and the previous step's state.

call_cell = lambda: cell(input_t, state)

If sequence_length was supplied, _rnn_step is called so that the computation respects each sequence's real length; otherwise call_cell() is invoked directly, i.e. the cell's call method is executed.

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

The output of the current time step is written into the output TensorArray.

output_ta_t = tuple(
          ta.write(time, out) for ta, out in zip(output_ta_t, output))

Note that the returned tuple increments time by 1; the other two return values are the updated output TensorArray and the state produced at the current time step.

return (time + 1, output_ta_t, new_state)

_rnn_step

Based on sequence_length, this method runs the step computation while, for every sequence whose real length has already been exceeded, emitting zeros as the output and passing the previous state through unchanged.

def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on `sequence_length`.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on whether we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_length[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_length[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  Args:
    time: Python int, the current time step
    sequence_length: int32 `Tensor` vector of size [batch_size]
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length
    zero_output: `Tensor` vector of shape [output_size]
    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
      or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations.  This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the cell returns a state tuple whose length does not match
      that returned by `state_size`.
  """

  # Convert state to a list for ease of use
  flat_state = nest.flatten(state)
  flat_zero_output = nest.flatten(zero_output)

  def _copy_one_through(output, new_output):
    # TensorArray and scalar get passed through.
    if isinstance(output, tensor_array_ops.TensorArray):
      return new_output
    if output.shape.ndims == 0:
      return new_output
    # Otherwise propagate the old or the new value.
    copy_cond = (time >= sequence_length)
    with ops.colocate_with(new_output):
      return array_ops.where(copy_cond, output, new_output)

  def _copy_some_through(flat_new_output, flat_new_state):
    # Use broadcasting select to determine which values should get
    # the previous state & zero output, and which values should get
    # a calculated state & output.
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)]
    return flat_new_output + flat_new_state

  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length, lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))

  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
  # but benefits from removing cond() and its gradient.  We should
  # profile with and without this switch here.
  if skip_conditionals:
    # Instead of using conditionals, perform the selective copy at all time
    # steps.  This is faster when max_seq_len is equal to the number of unrolls
    # (which is typical for dynamic_rnn).
    new_output, new_state = call_cell()
    nest.assert_same_structure(state, new_state)
    new_state = nest.flatten(new_state)
    new_output = nest.flatten(new_output)
    final_output_and_state = _copy_some_through(new_output, new_state)
  else:
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length, empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)

  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
    raise ValueError("Internal error: state and output were not concatenated "
                     "correctly.")
  final_output = final_output_and_state[:len(flat_zero_output)]
  final_state = final_output_and_state[len(flat_zero_output):]

  for output, flat_output in zip(final_output, flat_zero_output):
    output.set_shape(flat_output.get_shape())
  for substate, flat_substate in zip(final_state, flat_state):
    if not isinstance(substate, tensor_array_ops.TensorArray):
      substate.set_shape(flat_substate.get_shape())

  final_output = nest.pack_sequence_as(
      structure=zero_output, flat_sequence=final_output)
  final_state = nest.pack_sequence_as(
      structure=state, flat_sequence=final_state)

  return final_output, final_state
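
The key copy-through operation in _copy_one_through boils down to a row-wise where (a sketch with a batch of 3 sequences evaluated at time step 2; the numbers are arbitrary):

import tensorflow as tf

time = tf.constant(2)
sequence_length = tf.constant([4, 2, 1])           # real lengths of the 3 sequences
old_state = tf.constant([[0.0], [0.0], [0.0]])     # value carried over from the previous step
new_state = tf.constant([[1.0], [2.0], [3.0]])     # value produced by call_cell() at this step

copy_cond = time >= sequence_length                # [False, True, True]
# Rows whose sequence has already finished keep the old value; the rest take the new one.
result = tf.where(copy_cond, old_state, new_state) # [[1.0], [0.0], [0.0]]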
