tf.reverse_sequence()简述
在看bidirectional_dynamic_rnn()的源码的时候,看到了代码中有调用 reverse_sequence()这一方法,于是又回去看了下这个函数的用法,发现还是有点意思的。根据名字就可以能看得出,这个方法主要是用来翻转序列的,就像双线LSTM中在反向传播那里需要从下文往上文处理一样,需要对序列做一个镜像的翻转处理。
先来看一下这个方法的定义:
reverse_sequence(
input,
seq_lengths,
seq_axis=None,
batch_axis=None,
name=None,
seq_dim=None,
batch_dim=None)
- 其中input是输入的需要翻转的目标张量,seq_lengths是一个张量;
- 其元素是input中每一处需要翻转时翻转的长度,在双向LSTM中这个值统一被设为输入语句的长度,代表着整句话都需要被翻转,而实际上张量中的元素值可以是不同的,下面的例子中就可以看出;
- seq_axis和seq_dim的关系,在源码中做了如下操作: