tensorflow之dynamic_rnn

转自:tensorflow高阶教程:tf.dynamic_rnn

tensorflow 的dynamic_rnn方法,我们用一个小例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,last_states,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,last_states是由(c,h)组成的tuple,均为[batch,128]。

到这里并没有什么不同,但是dynamic有个参数:sequence_length,这个参数用来指定每个example的长度,比如上面的例子中,我们令 sequence_length为[20,13],表示第一个example有效长度为20,第二个example有效长度为13,当我们传入这个参数的时候,对于第二个example,TensorFlow对于13以后的padding就不计算了,其last_states将重复第13步的last_states直至第20步,而outputs中超过13步的结果将会被置零。

实例

import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)
X[1,6:]=0
X_lengths=[10,6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=3, state_is_tuple=True)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)


sess=tf.Session()
sess.run(tf.global_variables_initializer())
print('outputs shpae{}:\n{}'.format(sess.run(outputs).shape,sess.run(outputs)))
print('outputs_last shpae{}:\n{}'.format(sess.run(outputs[:,-1,:]).shape,sess.run(outputs[:,-1,:])))
print('h shpae{}:\n{}'.format(sess.run(last_states.h).shape,sess.run(last_states.h)))
print('c shpae{}:\n{}'.format(sess.run(last_states.c).shape,sess.run(last_states.c)))

输出结果:

outputs shpae(2, 10, 3):
[[[ 0.23110803 -0.0107572   0.20172482]
  [ 0.36795051 -0.00178671 -0.02944196]
  [ 0.14378943 -0.06261828 -0.02818266]
  [-0.01201567 -0.12757061 -0.03590098]
  [-0.04926666 -0.17189351 -0.11857959]
  [-0.02837513 -0.08652785  0.09262549]
  [ 0.03387051 -0.02545082  0.18292322]
  [ 0.04326441 -0.09480653  0.21489134]
  [ 0.22599051  0.02791138  0.15339913]
  [ 0.13498046  0.27791892  0.24671301]]

 [[ 0.04049265 -0.12261707  0.14065445]
  [-0.06155683 -0.22343117  0.21023283]
  [-0.0331372  -0.23284029  0.13399217]
  [-0.15608016 -0.35463616 -0.05002852]
  [-0.19984934 -0.0380892  -0.16395056]
  [-0.20881976 -0.24745148 -0.18732526]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]
  [ 0.          0.          0.        ]]]
outputs_last shpae(2, 3):
[[ 0.13498046  0.27791892  0.24671301]
 [ 0.          0.          0.        ]]
h shpae(2, 3):
[[ 0.13498046  0.27791892  0.24671301]
 [-0.20881976 -0.24745148 -0.18732526]]
c shpae(2, 3):
[[ 0.16218479  0.98683522  0.42632254]
 [-0.6683409  -0.33597254 -0.26919713]]

可以看出,对于第二个example超过6步的outputs,是直接被设置成0了,而last_states将7-10步的输出重复第6步的输出。可见节省了不少的计算开销

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值