import tensorflow as tf
import numpy as np
# Demonstrate that two tf.nn.dynamic_rnn calls built from the SAME cell object
# share weights: running the second RNN on X[:1] is expected to reproduce the
# first RNN's outputs for batch element 0 (up to float32 rounding).

# A 4-unit LSTM cell; the same object is passed to both dynamic_rnn calls below.
cell = tf.contrib.rnn.BasicLSTMCell(num_units=4, state_is_tuple=True)

# Inputs: a (batch=2, time=10, features=8) tensor plus per-example lengths.
X = tf.placeholder(tf.float32, (2, 10, 8))
# sequence_length is an integer vector with one length per batch element.
X_lengths = tf.placeholder(tf.int32, (2,))

# First RNN: runs over the full batch of 2 sequences.
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=X_lengths,
    inputs=X)

# Train a few steps through the first RNN only. If the second RNN reuses the
# same weights, its outputs will reflect these updates as well.
label = tf.ones((2, 10, 4))
# minimize() expects a scalar objective; reduce the element-wise error to MSE
# (a raw `label - outputs` tensor would implicitly minimize its sum instead).
loss = tf.reduce_mean(tf.square(label - outputs))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Second RNN: same cell object, only the first example of the batch.
outputs2, last_states2 = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=X_lengths[:1],
    inputs=X[:1])

with tf.Session() as sess:
    X1 = np.ones((2, 10, 8))
    # The second example has length 6; its outputs past step 6 come back as zeros.
    X_lengths1 = [10, 6]
    feed = {X: X1, X_lengths: X_lengths1}

    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(optimizer, feed_dict=feed)

    # Fetch into new names so the graph tensors (`outputs`, ...) are not shadowed.
    out_val, states_val, out2_val, states2_val = sess.run(
        (outputs, last_states, outputs2, last_states2), feed_dict=feed)
    print(states_val)
    print(out_val)
    print(states2_val)
    print(out2_val)
    # Element-wise difference; expected to be zero up to float32 rounding.
    print(out2_val - out_val[0])
    print(np.allclose(out2_val[0], out_val[0], atol=1e-5))
运行后的输出如下：
LSTMStateTuple(c=array([[3.6967206 , 0.36333948, 2.6907382 , 0.7764288 ],
[2.9159408 , 0.34533507, 2.3552992 , 0.7168741 ]], dtype=float32), h=array([[0.83917683, 0.16514859, 0.82227004, 0.2351862 ],
[0.83666867, 0.15811422, 0.8160259 , 0.22364864]], dtype=float32))
[[[0.5070681 0.08897948 0.5242556 0.1298186 ]
[0.72878003 0.11847612 0.7119349 0.16895662]
[0.8013882 0.13465777 0.77578974 0.19117852]
[0.8250746 0.14580028 0.80016226 0.20640413]
[0.8334302 0.15326948 0.8107919 0.21670982]
[0.83666867 0.15811422 0.8160259 0.22364864]
[0.83804226 0.16120268 0.81887186 0.22833543]
[0.8386776 0.16315424 0.8205446 0.2315192 ]
[0.8389986 0.16438086 0.8215884 0.23369396]
[0.83917683 0.16514859 0.82227004 0.2351862 ]]
[[0.5070681 0.08897948 0.5242556 0.1298186 ]
[0.72878003 0.11847612 0.7119349 0.16895662]
[0.8013882 0.13465777 0.77578974 0.19117852]
[0.8250746 0.14580028 0.80016226 0.20640413]
[0.8334302 0.15326948 0.8107919 0.21670982]
[0.83666867 0.15811422 0.8160259 0.22364864]
[0. 0. 0. 0. ]
[0. 0. 0. 0. ]
[0. 0. 0. 0. ]
[0. 0. 0. 0. ]]]
LSTMStateTuple(c=array([[3.6967204 , 0.36333948, 2.6907384 , 0.7764287 ]], dtype=float32), h=array([[0.8391768 , 0.16514859, 0.8222699 , 0.23518619]], dtype=float32))
[[[0.5070681 0.08897948 0.5242555 0.12981859]
[0.72878003 0.11847612 0.7119349 0.16895662]
[0.8013882 0.13465776 0.77578974 0.19117847]
[0.8250748 0.14580028 0.8001623 0.20640409]
[0.8334301 0.15326953 0.8107919 0.21670981]
[0.8366687 0.15811422 0.816026 0.22364862]
[0.83804226 0.16120267 0.818872 0.22833538]
[0.8386776 0.16315423 0.8205446 0.2315192 ]
[0.8389986 0.16438086 0.82158846 0.23369391]
[0.8391768 0.16514859 0.8222699 0.23518619]]]
[[[ 0.0000000e+00 0.0000000e+00 -5.9604645e-08 -1.4901161e-08]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 -1.4901161e-08 0.0000000e+00 -4.4703484e-08]
[ 1.7881393e-07 0.0000000e+00 5.9604645e-08 -4.4703484e-08]
[-5.9604645e-08 4.4703484e-08 0.0000000e+00 -1.4901161e-08]
[ 5.9604645e-08 0.0000000e+00 5.9604645e-08 -1.4901161e-08]
[ 0.0000000e+00 -1.4901161e-08 1.1920929e-07 -4.4703484e-08]
[ 0.0000000e+00 -1.4901161e-08 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 5.9604645e-08 -4.4703484e-08]
[-5.9604645e-08 0.0000000e+00 -1.1920929e-07 -1.4901161e-08]]]
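用同一个 cell 创建两个 LSTM（两次调用 tf.nn.dynamic_rnn）：
第一个的输入为 X，形状为 [2, 10, 8]；
第二个的输入为 X[:1]，即第一个输入在第 0 维（batch 维）上的第一个样本，形状为 [1, 10, 8]；
两次调用对该样本的结果一致（差值只在 1e-7 量级，属于 float32 舍入误差）；
原因是 RNN 权重共享：两次调用传入的是同一个 cell 对象，因此使用同一套权重。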