RNN是处理序列数据的模型,上面介绍的Cell只是序列数据中每一步需要计算的内容,所以还需要一个流程控制类从输入数据的第一步开始,计算每一步的Cell,保存返回值并传递给下一步的Cell使用,并且把最终的返回值作为整个RNN模型的返回值返回。当今的RNN从计算方向上分为单向RNN和双向RNN;从层数上分为单层RNN和多层RNN,在TensorFlow中,单向RNN控制类为tf.nn.dynamic_rnn,双向RNN控制类为tf.nn.bidirectional_dynamic_rnn,多层RNN是在普通的RNNCell外面添加了一个tensorflow.contrib.rnn.MultiRNNCell类,完后根据单向还是双向还是使用dynamic_rnn或者bidirectional_dynamic_rnn。
dynamic_rnn
在构建RNN模型时,在创建完我们需要的RNNCell类以后,把cell和输入参数传递给tf.nn.dynamic_rnn后,就算完成了基于具体input的RNN模型的创建了,tf.nn.dynamic_rnn返回每一步的输出outputs和最后一步的状态state,如下所示:
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
dynamic_rnn方法接受的参数除了cell和inputs以外,sequence_length用于标识inputs中每一个输入的真实序列长度,因为inputs其实是一批数据,准备数据时需要按照其中最长的一个输入的长度定义inputs第二维的长度,其它短些的数据需要在后面补位,如果指定了sequence_length,那么在计算每一个输入数据时,只会计算真实长度的内容,剩下的补位数据不计算。initial_state可以定义状态的初始化内容。
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
dtype=None, parallel_iterations=None, swap_memory=False,
time_major=False, scope=None):
一般情况下用户整理好的输入数据的维度是[batch_size, time_step, input_size]的,在dynamic_rnn方法中为了方便计算,需要把数据的维度转换成[time_stet, batch_size, input_size],如下所示,_transpose_batch_time方法即负责此转换:
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
如果用户传递了状态state的初始化值,则把此值直接设置为state的初始化值,否则调用RNNCell父类的zero_state把state初始化为全0的值。
if initial_state is not None:
state = initial_state
else:
if not dtype:
raise ValueError("If there is no initial_state, you must give a dtype.")
state = cell.zero_state(batch_size, dtype)
完后就调用_dynamic_rnn_loop方法进行下一步处理了。_dynamic_rnn_loop方法返回了每一步的输出outputs和最后一步的状态final_state。
(outputs, final_state) = _dynamic_rnn_loop(
cell,
inputs,
state,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
sequence_length=sequence_length,
dtype=dtype)
https://blog.csdn.net/wh1312142954/article/details/80213117
https://blog.csdn.net/luoyuxiang1022/article/details/81866329
https://cairohy.github.io/2017/06/05/ml-coding-summarize/Tensorflow的RNN和Attention相关/