tf.nn.dynamic_rnn()在处理变长输入时特别方便,具体解释可以看这篇博文
#coding=utf-8
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)
# 第二个example长度为6
X[1,6:] = 0
X_lengths = [10, 6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
result = tf.contrib.learn.run_n(
{"outputs": outputs, "last_states": last_states},
n=1,
feed_dict=None)
a = result[0]
print(a)
assert result[0]["outputs"].shape == (2, 10, 5)
# 第二个example中的outputs超过6步(7-10步)的值应该为0
assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()
输出:
[[ 0.13337141 0.10697078 0.11238842 -0.16187296 0.04447445]
[ 0.06800554 0.29581101 0.14454009 -0.11857419 -0.08062822]
[-0.02766501 0.20230338 0.25521379 -0.10196185 0.02908 ]
[ 0.07160553 0.39891538 0.03997988 -0.43861938 -0.00340179]
[-0.12841535 0.35346241 0.08577594 -0.29574161 -0.06306395]
[-0.19022677 0.11256105 -0.13190501 -0.20170257 -0.02765217]
[-0.04303006 -0.42253068 -0.02945417 -0.0817529 0.03569792]
[-0.01433148 0.00066725 -0.08619441 -0.1063433 0.36421112]
[ 0.19718385 0.06653057 0.02880462 -0.31631752 0.04064322]
[ 0.07665874 0.15330013 0.11820727 -0.28386946 -0.06841132]]
-------------------------------------------------------
[[ 0.09817442 0.12635493 0.14153314 -0.13827174 -0.14350587]
[ 0.09484242 0.05155221 0.11429032 -0.04175748 -0.11621833]
[ 0.21802519 0.17491722 0.17653461 -0.2161642 -0.17876485]
[ 0.05461165 -0.01181785 0.31818148 -0.18725258 -0.06083239]
[ 0.03753194 0.04578742 0.30538616 -0.09413831 -0.41238963]
[-0.04687686 0.01701181 0.21276684 -0.02761401 -0.07971509]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]