在看bidirectional_dynamic_rnn()的源码的时候,看到了为何可以双向的核心代码reverse_sequence(),下面就来说说双向bidirectional_dynamic_rnn()是如何进行反转的。
reverse_sequence(
input,
seq_lengths,
seq_axis=None,
batch_axis=None,
name=None,
seq_dim=None,
batch_dim=None)
其中input是输入的需要翻转的目标张量,seq_lengths是一个张量;
其元素是input中每一处需要翻转时翻转的长度,在双向LSTM中这个值统一被设为输入语句的长度,代表着整句话都需要被翻转
“此操作首先沿着维度batch_axis对input进行分割,并且对于每个切片 i,将前 seq_lengths 元素沿维度 seq_axis 反转”。实际上通俗来理解,就是对于张量input中的第batch_axis维中的每一个子张量,在这个子张量的第seq_axis维上进行翻转,翻转的长度为 seq_lengths 张量中对应的数值。
举个例子,如果 **batch_axis=0,seq_axis=1,则代表我希望每一行为单位分开处理,对于每一行中的每一列进行翻转。**相反的,**如果 batch_axis=1,seq_axis=0,则是以列为单位,对于每一列的张量,进行相应行的翻转。**回头去看双向RNN的源码,就可以理解当time_major这一属性不同时,time_dim 和 batch_dim 这一对组合的取值为什么恰好是相反的了。
举个例子:
a = tf.constant([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
lx = tf.constant([2,4,3],tf.int64)
# 由于batch_axis= 0,seq_axis = 1,所以按照每一行为单位进行翻转
# 所以lx的维度只能为3,因为只有3行,同时每个维度小于等于每一行的长度4
# 当lx=n是,则表示从第n个位置翻转,第n个位置表示第1个位置,(n-1)表示第2个位置,原来的第一个位置表示第n个位置,完全翻转
x = tf.reverse_sequence(a,seq_lengths=lx,batch_axis= 0,seq_axis = 1)
ly = tf.constant([2,2,3,1],tf.int64)
# 由于batch_axis= 1,seq_axis = 0,所以按照每一列为单位进行翻转
# 所以ly的维度只能为4,因为只有4列,同时每个维度小于等于每一列的长度3
# 当lx=n是,则表示从第n个位置翻转,第n个位置表示第1个位置,(n-1)表示第2个位置,原来的第一个位置表示第n个位置,完全翻转
y = tf.reverse_sequence(a,seq_lengths=ly,batch_axis= 1,seq_axis = 0)
with tf.Session() as sess:
print(sess.run(x))
print('======================')
print(sess.run(y))
得到结果如下:
[[ 2 1 3 4]
[ 8 7 6 5]
[11 10 9 12]]
======================
[[ 5 6 11 4]
[ 1 2 7 8]
[ 9 10 3 12]]