TensorFlow raw_rnn - 实现seq2seq模式中将上一时刻的输出作为下一时刻的输入

核心问题

在大部分情况下,RNN的输入序列都是预先定义好的,最为常见的就是训练语料中的sentence。但在序列生成任务中,有时我们希望根据 t 时刻预测出的结果(经过一定变形)作为 t+1 时刻的输入,也就是说一开始我们手中并没有一个完整的句子,往往最开始(t = 0时刻)我们只有一个开始标记"<START>",将<START>输入RNN得到初始时刻的输出\large \mathbf{y_0},然后将\large \mathbf{y_0}(或进行一定的变换)作为下一时刻(t = 1时刻)的输入,即\large \mathbf{x_1} = f(\mathbf{y_0}),再将\large \mathbf{x_1}输入到RNN得到输出\large \mathbf{y_1},以此类推,直到预测到指定长度(或者终止标记"<END>")后停止预测。

可见,这个过程是一个动态的过程,其实现关键是在时刻间进行一定的处理(将 \large t_i 时刻的输出处理后作为\large t_{i+1}时刻的输入),但现在常见的RNN封装都没有提供在计算时序间进行处理的操作(包括dynamic_rnn,它的dynamic只是指的以循环方式进行而不是遍历预先定义好的输入序列,这里不过多介绍,更多可以自行查询)。而tf.nn.raw_rnn则提供了这种更底层细节上的操作支持。

tf.nn.raw_rnn api简介

先给出部分源码,思路很清晰

def raw_rnn(cell, loop_fn,
            parallel_iterations=None, swap_memory=False, scope=None):
  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

  **NOTE: This method is still in testing, and the API may change.**

  This function is a more primitive version of `dynamic_rnn` that provides
  more direct access to the inputs each iteration.  It also provides more
  control over when to start and finish reading the sequence, and
  what to emit for the output.

  For example, it can be used to implement the dynamic decoder of a seq2seq
  model.

  Instead of working with `Tensor` objects, most operations work with
  `TensorArray` objects directly.

  The operation of `raw_rnn`, in pseudo-code, is basically the following:

  ```python
  time = tf.constant(0, dtype=tf.int32)
  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
      time=time, cell_output=None, cell_state=None, loop_state=None)
  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
  state = initial_state
  while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
  return (emit_ta, state, loop_state)
  ...```

  with the additional properties that output and state may be (possibly nested)
  tuples, as determined by `cell.output_size` and `cell.state_size`, and
  as a result the final `state` and `emit_ta` may themselves be tuples.

  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:

  ```python
  inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                          dtype=tf.float32)
  sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
  inputs_ta = inputs_ta.unstack(inputs)

  cell = tf.contrib.rnn.LSTMCell(num_units)

  def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
      next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
  outputs = outputs_ta.stack()
 ... ```

其中,关键在于loop_fn()函数,正是使用该函数来处理时序间的输出输入转换

loop_fn是一个函数,这个函数在rnn的相邻时间步之间被调用。  
函数的总体调用过程为:

1. 初始时刻,先调用一次loop_fn,获取第一个时间步的cell的输入,loop_fn中进行读取初始时刻的输入。
2. 进行cell自环 (output, cell_state) = cell(next_input, state)
3. 在t时刻RNN计算结束时,cell有一组输出cell_output和状态cell_state,都是tensor;
4. 到t+1时刻开始进行计算之前,loop_fn被调用,调用的形式为loop_fn( t, cell_output, cell_stat, loop_state),而被期待的输出为:(finished, next_input, initial_state, emit_output, loop_state);
5. RNN采用loop_fn返回的next_input作为输入,initial_state作为状态,计算得到新的输出。

在每次执行(output, cell_state) =  cell(next_input, state)后,执行loop_fn()进行数据的准备和处理。

emit_structure 即上文的emit_output将会按照时间存入emit_ta中。 

loop_state  记录rnn loop的变量的状态。用作记录状态 

Tf.where 是用来实现dynamic的。 


### loop_fn()

```python
(elements_finished, next_input, next_cell_state, emit_output, next_loop_state) = loop_fn(time, cell_output, cell_state, loop_state)

至此,raw_rnn的使用在代码中已经很明确了,主要是按个人需求自定义loop_fn()中的操作。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值