tf.nn.dynamic_rnn和tf.nn.static_rnn

一、tf.nn.dynamic_rnn

1. 函数定义

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

2. 参数说明

具体可以参考:TensorFlow函数教程:tf.nn.dynamic_rnn
输入参数:

  • cell:RNNCell的一个实例。
  • inputs:输入数据。
    如果time_major == False(默认),则是一个shape为[batch_size, max_time, …]的Tensor,或者这些元素的嵌套元组。
    如果time_major == True,则是一个shape为[max_time, batch_size, …]的Tensor,或这些元素的嵌套元组。

输出参数:

  • outputs:RNN输出的Tensor。
    如果time_major == False(默认),outputs为shape为[batch_size, max_time, cell.output_size]的Tensor。
    若要输出对应最后的 last_output,我们必须要先对outputs进行转置,即 output=tf.transpose(outputs, [1,0,2]),此时对应输出output[-1]才是对应最终的输出。
    如果time_major == True,outputs为shape为[max_time, batch_size, cell.output_size]的Tensor。此时无需转置,outputs[-1]即为最终输出。
  • state:最终的状态。即序列中最后一个cell输出的状态。
    如果cell.state_size是int,则state的shape为[batch_size, cell.state_size]。
    如果它是TensorShape,则将形成[batch_size] + cell.state_size。
    如果它是一个(可能是嵌套的)int或TensorShape元组,那么这将是一个具有相应shape的元组。
    如果单元格是LSTMCells,则state将是包含每个单元格的LSTMStateTuple的元组。

注:
batch_size是输入的这批数据的数量
max_time是这批数据中序列的最长长度,即max(time_steps)
cell.output_size是rnn cell中神经元的个数,具体可参考具体RNNCell类中output_size函数的具体实现。(其值通常是是rnn cell中神经元的个数)

3. 代码实例

# 输入数据X的shape为[batch_size,time_steps,input_dim]
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_units, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
outputs, state = tf.nn.dynamic_rnn(cell, X, initial_state=init_state, time_major=False)
# 若是要取最后的输出数据,则需要先进行转置,将time_steps维度提到前面来
outputs = tf.transpose(outputs,[1,0,2])
last_output = outputs[-1]

二、tf.nn.static_rnn

这里 tf.nn.static_rnn 和 tf.contrib.rnn.static_rnn 其实本质是一样的。

1. 函数定义

tf.nn.static_rnn(
    cell,
    inputs,
    initial_state=None,
    dtype=None,
    sequence_length=None,
    scope=None
)

2. 参数说明

具体可以参考:TensorFlow函数教程:tf.nn.static_rnn
输入参数:

  • cell:RNNCell的一个实例。
  • inputs:输入数据。输入长度为time_steps的list,其中每个Tensor的shape为[batch_size, input_size];或这些元素的嵌套元组。
    这里举例简单说明一下:
    原始输入X:[time_steps, batch_size, input_dim]
    对原始输入进行处理得到list:
    input=tf.unstack(X ,time_steps,1)

输出参数:

  • outputs是长度为time_steps的列表,其中每个Tensor的shape为 [batch_size,n_hidden]。因此对应输出最后一个outputs[-1]即为我们需要的值。
  • state是最终状态

3. 代码实例

# 输入数据X的shape为[batch_size,time_steps,input_dim]
input = tf.unstack(X, time_steps, 1)  
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias=1)
outputs, state = tf.rnn.static_rnn(cell, input)

补充说明:outputs和state有什么关系?

结论上来说,如果cell为LSTM,那么state是个tuple,分别代表Ct和Ht,其中Ht与outputs中的对应的最后一个时刻的输出相等。
假设state形状为[2, batch_size, cell.output_size ],outputs形状为 [batch_size, max_time, cell.output_size ],那么state[1, batch_size, : ] == outputs[ batch_size, -1, : ];如果cell为GRU,那么同理,state其实就是Ht,state ==outputs[-1]

参考资料:
tf.nn.dynamic_rnn的输出outputs和state含义

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值