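TensorFlow RNN API reading notes
Overview
TensorFlow RNN API
The RNN API runs an RNNCell in a loop over time steps. The same cell is applied once per step, and each step's state is fed into the next step.
Commonly used RNN APIs
- APIs with the static prefix require input sequences of a fixed length. The graph is built for a fixed number of time steps, and the inputs are passed as a Python list with one tensor per step.
- APIs with the dynamic prefix accept an optional sequence_length argument (it can be a list) that gives the length of each sequence in the batch, so variable-length sequences are handled dynamically. In the implementation, a tensor recording the sequence lengths controls how many iterations the RNN loop runs. Past a sequence's length, its output is zeroed and its state is copied through unchanged.
- tf.nn.raw_rnn is the lowest-level RNN API. It can be used to implement all kinds of customized behavior and is very handy.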
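Whichever variant is used, the core idea is the same loop. The following is a minimal pure-Python sketch of that loop, not TensorFlow code: `toy_cell` and `unroll` are hypothetical names, and the scalar cell only stands in for an RNNCell, which maps (input at step t, previous state) to (output, new state).

```python
from typing import Callable, List, Tuple

# A cell maps (input at time t, previous state) -> (output, new state),
# playing the role an RNNCell plays in TensorFlow.
Cell = Callable[[float, float], Tuple[float, float]]


def toy_cell(x: float, h: float) -> Tuple[float, float]:
    """Hypothetical scalar cell: h_t = 0.5 * h_{t-1} + x_t; the output equals the new state."""
    h_new = 0.5 * h + x
    return h_new, h_new


def unroll(cell: Cell, inputs: List[float], initial_state: float) -> Tuple[List[float], float]:
    """Apply the same cell once per time step, threading the state through."""
    state = initial_state
    outputs = []
    for x in inputs:              # one loop iteration per time step
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state         # every per-step output, plus the final state


outputs, final_state = unroll(toy_cell, [1.0, 2.0, 3.0], initial_state=0.0)
print(outputs, final_state)
```

Here `outputs` collects one value per step (like the `outputs` tensor of the RNN APIs) and `final_state` is the state after the last step (like the returned `state`).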
- tf.nn.static_rnn
- tf.nn.static_state_saving_rnn
- tf.nn.static_bidirectional_rnn
- tf.contrib.rnn.stack_bidirectional_dynamic_rnn
- tf.nn.dynamic_rnn
- tf.nn.bidirectional_dynamic_rnn
- tf.nn.raw_rnn
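To make the sequence_length behavior of the dynamic APIs concrete, here is a pure-Python sketch of the semantics for a padded batch: beyond each sequence's length the output is zero and the state is copied through, and the loop stops after the longest sequence, as described in the notes above. `toy_cell` and `masked_unroll` are hypothetical helpers; this illustrates the behavior, it is not TensorFlow's implementation.

```python
from typing import List, Tuple


def toy_cell(x: float, h: float) -> Tuple[float, float]:
    """Hypothetical scalar cell: h_t = 0.5 * h_{t-1} + x_t."""
    h_new = 0.5 * h + x
    return h_new, h_new


def masked_unroll(batch: List[List[float]],
                  sequence_length: List[int]) -> Tuple[List[List[float]], List[float]]:
    """Sketch of sequence_length handling on a padded, batch-major input.

    batch: [batch_size][max_time] padded inputs (like time_major=False).
    """
    max_time = len(batch[0])
    loop_bound = min(max_time, max(sequence_length))  # no steps beyond the longest sequence
    states = [0.0] * len(batch)                        # zero initial state per sequence
    outputs = [[0.0] * max_time for _ in batch]        # padded positions stay 0.0
    for t in range(loop_bound):
        for b, seq in enumerate(batch):
            if t < sequence_length[b]:
                outputs[b][t], states[b] = toy_cell(seq[t], states[b])
            # else: output stays 0.0 and the state is carried over unchanged
    return outputs, states


outputs, states = masked_unroll([[1.0, 2.0, 3.0], [1.0, 2.0, 0.0]], sequence_length=[3, 2])
print(outputs, states)
```

The second sequence has length 2, so its third output is 0.0 and its final state is the one from step 2, regardless of the padding value.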
dynamic_rnn
```python
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
```

The docstring opens: "Creates a recurrent neural network specified by RNNCell `cell`. Performs fully dynamic unrolling of `inputs`."
Example:
```python
import tensorflow as tf  # TensorFlow 1.x API

batch_size, max_time, input_size, hidden_size = 32, 10, 8, 64  # example sizes
input_data = tf.placeholder(tf.float32, [batch_size, max_time, input_size])

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
```
The example above and the one below are the usage examples given in the official docstring.
```python
import tensorflow as tf  # TensorFlow 1.x API

# example input of shape [batch_size, max_time, input_size]
data = tf.placeholder(tf.float32, [None, None, 32])

# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is an N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)
```
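The shapes in the MultiRNNCell example follow from how the cells are stacked: at every time step each layer's output becomes the next layer's input, `outputs` holds only the top layer's output (hence the last dimension of 256), and the state keeps one entry per layer. Below is a pure-Python sketch of one stacked step; `make_cell` and `multi_cell_step` are hypothetical, and the scalar cells merely stand in for the two LSTMCells. This is not MultiRNNCell's actual code.

```python
from typing import Callable, List, Tuple

Cell = Callable[[float, float], Tuple[float, float]]


def make_cell(decay: float) -> Cell:
    """Hypothetical scalar cell factory: h_t = decay * h_{t-1} + x_t."""
    def cell(x: float, h: float) -> Tuple[float, float]:
        h_new = decay * h + x
        return h_new, h_new
    return cell


def multi_cell_step(cells: List[Cell], x: float,
                    states: Tuple[float, ...]) -> Tuple[float, Tuple[float, ...]]:
    """One time step of a stacked cell: layer i's output is layer i+1's input."""
    new_states = []
    for cell, h in zip(cells, states):
        x, h_new = cell(x, h)      # this layer's output feeds the next layer
        new_states.append(h_new)
    return x, tuple(new_states)    # top-layer output, one state per layer


cells = [make_cell(0.5), make_cell(1.0)]
states = (0.0, 0.0)
outputs = []
for x in [1.0, 2.0]:               # unroll the stacked cell over two time steps
    out, states = multi_cell_step(cells, x, states)
    outputs.append(out)
print(outputs, states)
```

The final `states` is a 2-tuple, one entry per layer, which mirrors why `state` in the example above is an N-tuple of LSTMStateTuple values.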
Args:
cell: An instance of RNNCell.
inputs: The RNN inputs.
If `time_major == False` (default), this must be a `Tensor` of shape:
`[batch_size, max_time, ...]`, or a nested tuple of such
elements.
If `time_major == True`, this must be a `Tensor` of shape:
`[max_time, batch_size, ...]`, or a nested tuple of such
elements.
This may also be a (possibly nested) tuple of Tensors satisfying
this property. The first two dimensions must match across all the inputs,
but otherwise the ranks