输入 input 是形状为 [batch, steps, input_dim] = [2, 10, 8] 的数组，其中：
batch = 2（样本数）
steps = 10（填充后的最大时间步数）
为了演示变长序列，将第二个样本（example，而不是 batch）的有效长度设为 6，其后的时间步视为填充。
import tensorflow as tf
import numpy as np
# Demo: run an LSTM over a zero-padded batch whose examples have different
# true lengths, passing `sequence_length` so dynamic_rnn stops at each
# example's real length instead of processing the padding.
BATCH_SIZE = 2    # number of examples in the batch
MAX_STEPS = 10    # padded (maximum) number of time steps
INPUT_DIM = 8     # features per time step
NUM_UNITS = 4     # LSTM hidden size

cell = tf.contrib.rnn.BasicLSTMCell(num_units=NUM_UNITS, state_is_tuple=True)
# Inputs are batch-major (batch, time, features), matching dynamic_rnn's
# default time_major=False.
X = tf.placeholder(tf.float32, (BATCH_SIZE, MAX_STEPS, INPUT_DIM))
# Per-example valid lengths: dynamic_rnn documents sequence_length as an
# int32/int64 vector of size [batch_size], so declare it that way.
X_lengths = tf.placeholder(tf.int32, (BATCH_SIZE,))
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=X_lengths,
    inputs=X)

with tf.Session() as sess:
    # Create the input data (cast to match the float32 placeholder).
    X1 = np.random.randn(BATCH_SIZE, MAX_STEPS, INPUT_DIM).astype(np.float32)
    # The second example is only 6 steps long; its steps 6..9 are padding.
    X_lengths1 = [MAX_STEPS, 6]
    sess.run(tf.global_variables_initializer())
    # Fetch into new names so the graph tensors `outputs` / `last_states`
    # are not overwritten by their numpy values (which would break any
    # later sess.run on them).
    outputs_val, last_states_val = sess.run(
        (outputs, last_states),
        feed_dict={X: X1, X_lengths: X_lengths1})
    print(last_states_val)
    print(outputs_val)
输出为：
last_states：LSTMStateTuple，其中 c 和 h 的形状均为 [2, 4]（batch, num_units）
LSTMStateTuple(c=array([[-0.43867487, 0.1597187 , -0.0061222 , 0.21639484],
[-0.34291208, 0.20014592, 0.5361769 , 0.9636594 ]],
dtype=float32), h=array([[-0.14935654, 0.08637159, -0.00271731, 0.14839597],
[-0.17922135, 0.12582389, 0.25058392, 0.6113648 ]],
dtype=float32))
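outputs：形状为 [2, 10, 4]（batch, steps, num_units）。可以看到第二个样本在有效长度 6 之后的时间步输出全部为 0；并且 last_states.h 的每一行等于对应样本最后一个有效时间步的输出（第一个样本为第 10 步，第二个样本为第 6 步）。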
[[[-0.02683123 0.02625443 0.16520582 0.13091983]
[ 0.03294514 -0.0934423 0.19723308 0.16243443]
[ 0.2861453 0.0485843 0.1357825 0.22259377]
[-0.15690039 0.24221653 -0.09397997 0.02853295]
[-0.35596934 0.21729332 -0.06792817 -0.10537033]
[-0.1089786 0.02690388 -0.0379463 -0.09407488]
[-0.05858612 0.02765152 0.10222299 -0.01180249]
[-0.25069004 0.10124367 0.3203107 -0.00553809]
[-0.12913188 0.15587549 -0.10340141 -0.03349639]
[-0.14935654 0.08637159 -0.00271731 0.14839597]]
[[ 0.02649872 -0.06729827 -0.2074864 -0.09794931]
[-0.178599 0.07660254 -0.22099023 -0.19292895]
[ 0.0912771 -0.10909905 -0.17696406 -0.05781336]
[ 0.17243923 -0.17442465 0.03783391 0.17704639]
[ 0.10158884 -0.150579 0.21208954 0.47918105]
[-0.17922135 0.12582389 0.25058392 0.6113648 ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]