双向bidirectional_dynamic_rnn()之 tf.reverse_sequence()详述

11 篇文章 0 订阅
8 篇文章 0 订阅

在看bidirectional_dynamic_rnn()的源码的时候,看到了为何可以双向的核心代码reverse_sequence(),下面就来说说双向bidirectional_dynamic_rnn()是如何进行反转的。

reverse_sequence(
  input,
  seq_lengths,
  seq_axis=None,
  batch_axis=None,
  name=None,
  seq_dim=None,
  batch_dim=None)

其中input是输入的需要翻转的目标张量,seq_lengths是一个张量;
其元素是input中每一处需要翻转时翻转的长度,在双向LSTM中这个值统一被设为输入语句的长度,代表着整句话都需要被翻转

“此操作首先沿着维度batch_axis对input进行分割,并且对于每个切片 i,将前 seq_lengths 元素沿维度 seq_axis 反转”。实际上通俗来理解,就是对于张量input中的第batch_axis维中的每一个子张量,在这个子张量的第seq_axis维上进行翻转,翻转的长度为 seq_lengths 张量中对应的数值。

举个例子,如果 **batch_axis=0,seq_axis=1,则代表我希望每一行为单位分开处理,对于每一行中的每一列进行翻转。**相反的,**如果 batch_axis=1,seq_axis=0,则是以列为单位,对于每一列的张量,进行相应行的翻转。**回头去看双向RNN的源码,就可以理解当time_major这一属性不同时,time_dim 和 batch_dim 这一对组合的取值为什么恰好是相反的了。
举个例子:

a = tf.constant([[1,2,3,4], [5,6,7,8], [9,10,11,12]])

lx = tf.constant([2,4,3],tf.int64) 
# 由于batch_axis= 0,seq_axis = 1,所以按照每一行为单位进行翻转
# 所以lx的维度只能为3,因为只有3行,同时每个维度小于等于每一行的长度4
# 当lx=n是,则表示从第n个位置翻转,第n个位置表示第1个位置,(n-1)表示第2个位置,原来的第一个位置表示第n个位置,完全翻转
x = tf.reverse_sequence(a,seq_lengths=lx,batch_axis= 0,seq_axis = 1)


ly = tf.constant([2,2,3,1],tf.int64) 
# 由于batch_axis= 1,seq_axis = 0,所以按照每一列为单位进行翻转
# 所以ly的维度只能为4,因为只有4列,同时每个维度小于等于每一列的长度3
# 当lx=n是,则表示从第n个位置翻转,第n个位置表示第1个位置,(n-1)表示第2个位置,原来的第一个位置表示第n个位置,完全翻转
y = tf.reverse_sequence(a,seq_lengths=ly,batch_axis= 1,seq_axis = 0)
with tf.Session() as sess:
  print(sess.run(x))
  print('======================')
  print(sess.run(y))

得到结果如下:

[[ 2  1  3  4]
 [ 8  7  6  5]
 [11 10  9 12]]
 ======================
[[ 5  6 11  4]
 [ 1  2  7  8]
 [ 9 10  3 12]]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值