教程地址
带2个隐藏层的LSTM
import tensorflow.compat.v1 as tf
import tensorflow as tf2
tf.disable_v2_behavior()
import numpy as np
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# --- Hyper-parameters -------------------------------------------------------

# Number of training examples drawn for each optimisation step.
BATCH_SIZE = 128

# Every image is reshaped to N_STEPS rows of N_INPUTS values each, so the
# recurrent net reads one row per time step.
N_STEPS, N_INPUTS = 28, 28

# Width of the one-hot label placeholder and of the output projection.
N_CLASSES = 10

# Size of the LSTM cell and of the input projection feeding it.
N_HIDDEN_UNITS = 128

# Learning rate handed to the Adam optimiser.
LR = 0.001

# Upper bound on the sample counter in the training loop; kept as a float
# and converted with int() where it is used.
TRAINING_ITERS = 3e5
# Fetch the MNIST train/test split through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values by 1/255 (assumes 8-bit intensities) and lay every image
# out as N_STEPS rows of N_INPUTS values, matching placeholder `x`.
x_train = (x_train / 255).reshape((-1, N_STEPS, N_INPUTS))
x_test = (x_test / 255).reshape((-1, N_STEPS, N_INPUTS))

# One-hot encode the integer labels by indexing into an identity matrix.
# The width is tied to N_CLASSES (instead of a literal 10) so the labels
# cannot drift out of sync with placeholder `y` and the output layer.
_one_hot_table = np.eye(N_CLASSES)
y_train = _one_hot_table[y_train]
y_test = _one_hot_table[y_test]
# Graph inputs: a batch of images (one row per time step) and one-hot labels.
x = tf.placeholder(tf.float32, shape=[None, N_STEPS, N_INPUTS])
y = tf.placeholder(tf.float32, shape=[None, N_CLASSES])

# Trainable weights, keyed by the layer they belong to:
#   'in'  projects each image row to the LSTM input width,
#   'out' projects the final LSTM hidden state to class scores.
W = {}
W['in'] = tf.Variable(tf.random_normal([N_INPUTS, N_HIDDEN_UNITS]))
W['out'] = tf.Variable(tf.random_normal([N_HIDDEN_UNITS, N_CLASSES]))

# Matching biases, every entry initialised to 0.1.
b = {}
b['in'] = tf.Variable(tf.constant(0.1, shape=[N_HIDDEN_UNITS]))
b['out'] = tf.Variable(tf.constant(0.1, shape=[N_CLASSES]))
def RNN(X, weights, b):
    """Build the classifier graph: input projection -> LSTM -> output projection.

    Args:
        X: input tensor; it is flattened to rows of N_INPUTS values, so it is
            expected to hold N_STEPS rows per example.
        weights: dict with entries 'in' and 'out' used for the two matmuls.
        b: dict with entries 'in' and 'out' holding the matching biases.

    Returns:
        The tensor obtained by projecting the LSTM's final hidden state
        through weights['out'] and adding b['out'] (un-normalised scores).
    """
    # Fold batch and time together so the input projection is a single matmul.
    flat_rows = tf.reshape(X, [-1, N_INPUTS])
    projected = tf.matmul(flat_rows, weights['in']) + b['in']

    # Unfold back to one sequence of N_STEPS projected rows per example.
    sequence = tf.reshape(projected, [-1, N_STEPS, N_HIDDEN_UNITS])

    cell = tf.nn.rnn_cell.BasicLSTMCell(
        N_HIDDEN_UNITS,
        forget_bias=1,
        state_is_tuple=True,
    )
    _, final_state = tf.nn.dynamic_rnn(cell, sequence, dtype=tf.float32)

    # With state_is_tuple=True the state is a (c, h) pair; index 1 is the
    # hidden output h after the last time step.
    last_hidden = final_state[1]
    return tf.matmul(last_hidden, weights['out']) + b['out']
# Un-normalised class scores for every example fed through `x`.
pred = RNN(x, W, b)

# Mean softmax cross-entropy between the scores and the one-hot labels.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred)
cost = tf.reduce_mean(per_example_loss)

# One Adam update of all trainable variables per run of `train_op`.
optimizer = tf.train.AdamOptimizer(LR)
train_op = optimizer.minimize(cost)

# Accuracy: fraction of examples whose highest-scoring class matches the label.
predicted_class = tf.argmax(pred, 1)
true_class = tf.argmax(y, 1)
correct_pred = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Evaluate on the test set whenever the sample counter is a multiple of sep_.
# NOTE: the counter advances in BATCH_SIZE (128) increments, so the check only
# succeeds at common multiples of 128 and 1000, i.e. every 16000 samples.
sep_ = 1000

eval_steps = []  # sample counts at which accuracy was actually measured
results = []     # test accuracy recorded at the matching entry of eval_steps

with tf.Session() as se:
    se.run(tf.global_variables_initializer())

    # `step` is the number of training samples consumed before this batch.
    for step in range(0, int(TRAINING_ITERS), BATCH_SIZE):
        # Draw a fresh mini-batch without repeats inside the batch.
        random_index = np.random.choice(x_train.shape[0], BATCH_SIZE, replace=False)
        batch_xs, batch_ys = x_train[random_index], y_train[random_index]
        se.run(train_op, feed_dict={x: batch_xs, y: batch_ys})

        if step % sep_ == 0:
            acc = se.run(accuracy, feed_dict={x: x_test, y: y_test})
            print(step, acc)
            eval_steps.append(step)
            results.append(acc)

# Plot accuracy against the steps at which it was really measured. The
# previous x values (sep_ * i) assumed one evaluation every sep_ samples,
# which does not match the cadence above and compressed the x-axis.
plt.plot(eval_steps, results)
y_major_locator = plt.MultipleLocator(.1)
ax = plt.gca()
ax.yaxis.set_major_locator(y_major_locator)
plt.ylim(0, 1)
plt.show()
![测试集准确率随训练样本数变化的曲线](https://i-blog.csdnimg.cn/blog_migrate/ceb822597af3ef5f430ee319b3566653.png)