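tf.nn.dynamic_rnn takes a sequence_length argument that holds the true length of each sequence in the batch. It defaults to None. If the samples in a batch have different lengths and sequence_length is left as None, the outputs returned by dynamic_rnn are not all zero at any time step (that is, as with static_rnn, the outputs computed on the padded positions are included too). In that case, how do I get the output at each sequence's true last position?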
import tensorflow as tf

def extract_axis_1(data, ind):
    """
    Get specified elements along axis 1 (the time axis) of a tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    # Pair every batch index with the time index requested for that sample.
    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    # Gather data[b, ind[b]] for every sample b in the batch.
    res = tf.gather_nd(data, indices)
    return res
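# With sequence_length=None, all padded time steps are computed as well.
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=None,
    inputs=X)
# The true lengths are [6, 4, 6], so the last valid (zero-based) indices are [5, 3, 5].
outputs2 = extract_axis_1(outputs, [5, 3, 5])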
outputs2 is then the output you are after.
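To see which elements the helper picks, here is a minimal NumPy sketch of the same indexing. It is only an analogue of the tf.stack / tf.gather_nd calls above, not TensorFlow code; the array values and the names seq_lengths, last_idx and last_outputs are made up for illustration.

```python
import numpy as np

# Hypothetical padded RNN outputs: batch of 3, 6 time steps, 2 hidden units.
outputs = np.arange(3 * 6 * 2, dtype=np.float64).reshape(3, 6, 2)
seq_lengths = np.array([6, 4, 6])  # true lengths, as in the example above

# Zero-based index of the last valid time step of each sample.
last_idx = seq_lengths - 1  # [5, 3, 5]

# Pair each batch index with its time index, like tf.stack([...], axis=1).
indices = np.stack([np.arange(outputs.shape[0]), last_idx], axis=1)

# NumPy analogue of tf.gather_nd on those (batch, time) pairs.
last_outputs = outputs[indices[:, 0], indices[:, 1]]
print(last_outputs.shape)  # (3, 2)
```

Each row of last_outputs is outputs[b, seq_lengths[b] - 1], which is why the indices passed to extract_axis_1 are the true lengths minus one.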
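If you pass the true lengths to dynamic_rnn through sequence_length instead, each sequence is computed only up to its true length, and the state is carried through unchanged after that. In that case you can simply take last_states.h (for an LSTM cell, whose state is an LSTMStateTuple).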
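The effect of passing the true lengths can be sketched with a toy NumPy RNN. This is not how dynamic_rnn is implemented; it only mimics the documented behaviour of sequence_length (outputs zeroed and state copied through once a sequence has ended). The function toy_rnn, the weights and the shapes are all hypothetical, and a plain tanh cell is assumed, so the state and the output coincide.

```python
import numpy as np

def toy_rnn(inputs, seq_lengths, w_x, w_h):
    """Toy tanh RNN that masks steps past each sequence's length (illustration only)."""
    batch, max_time, _ = inputs.shape
    hidden = w_h.shape[0]
    state = np.zeros((batch, hidden))
    outputs = np.zeros((batch, max_time, hidden))
    for t in range(max_time):
        new_state = np.tanh(inputs[:, t] @ w_x + state @ w_h)
        valid = (t < seq_lengths)[:, None]  # True while still inside the sequence
        # Past the end of a sequence: keep the old state, emit zeros.
        state = np.where(valid, new_state, state)
        outputs[:, t] = np.where(valid, new_state, 0.0)
    return outputs, state

rng = np.random.default_rng(0)
inputs = rng.normal(size=(3, 6, 4))  # batch of 3, 6 time steps, 4 features
seq_lengths = np.array([6, 4, 6])
w_x = rng.normal(size=(4, 2))
w_h = rng.normal(size=(2, 2))

outputs, final_state = toy_rnn(inputs, seq_lengths, w_x, w_h)
last_outputs = outputs[np.arange(3), seq_lengths - 1]
print(np.allclose(final_state, last_outputs))  # True
```

With the lengths supplied, the final state already equals the output at each sequence's last valid step, so no extra gather is needed.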