转自:tensorflow高阶教程:tf.dynamic_rnn
tensorflow 的dynamic_rnn方法,我们用一个小例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,last_states,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,last_states是由(c,h)组成的tuple,均为[batch,128]。
到这里并没有什么不同,但是dynamic有个参数:sequence_length,这个参数用来指定每个example的长度,比如上面的例子中,我们令 sequence_length为[20,13],表示第一个example有效长度为20,第二个example有效长度为13,当我们传入这个参数的时候,对于第二个example,TensorFlow对于13以后的padding就不计算了,其last_states将重复第13步的last_states直至第20步,而outputs中超过13步的结果将会被置零。
实例
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)
X[1,6:]=0
X_lengths=[10,6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=3, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
sess=tf.Session()
sess.run(tf.global_variables_initializer())
print('outputs shpae{}:\n{}'.format(sess.run(outputs).shape,sess.run(outputs)))
print('outputs_last shpae{}:\n{}'.format(sess.run(outputs[:,-1,:]).shape,sess.run(outputs[:,-1,:])))
print('h shpae{}:\n{}'.format(sess.run(last_states.h).shape,sess.run(last_states.h)))
print('c shpae{}:\n{}'.format(sess.run(last_states.c).shape,sess.run(last_states.c)))
输出结果:
outputs shpae(2, 10, 3):
[[[ 0.23110803 -0.0107572 0.20172482]
[ 0.36795051 -0.00178671 -0.02944196]
[ 0.14378943 -0.06261828 -0.02818266]
[-0.01201567 -0.12757061 -0.03590098]
[-0.04926666 -0.17189351 -0.11857959]
[-0.02837513 -0.08652785 0.09262549]
[ 0.03387051 -0.02545082 0.18292322]
[ 0.04326441 -0.09480653 0.21489134]
[ 0.22599051 0.02791138 0.15339913]
[ 0.13498046 0.27791892 0.24671301]]
[[ 0.04049265 -0.12261707 0.14065445]
[-0.06155683 -0.22343117 0.21023283]
[-0.0331372 -0.23284029 0.13399217]
[-0.15608016 -0.35463616 -0.05002852]
[-0.19984934 -0.0380892 -0.16395056]
[-0.20881976 -0.24745148 -0.18732526]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]]]
outputs_last shpae(2, 3):
[[ 0.13498046 0.27791892 0.24671301]
[ 0. 0. 0. ]]
h shpae(2, 3):
[[ 0.13498046 0.27791892 0.24671301]
[-0.20881976 -0.24745148 -0.18732526]]
c shpae(2, 3):
[[ 0.16218479 0.98683522 0.42632254]
[-0.6683409 -0.33597254 -0.26919713]]
可以看出,对于第二个example超过6步的outputs,是直接被设置成0了,而last_states将7-10步的输出重复第6步的输出。可见节省了不少的计算开销