长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。
在 TensorFlow 中,基础的 LSTM 单元格声明为:tf.contrib.rnn.BasicLSTMCell ( num_units ),其中 num_units 指一个 LSTM 单元格中的单元数,相当于前馈神经网络中的隐藏层神经元个数,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。
在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义为:tf.static_rnn ( cell, inputs ),其中 inputs 接受 shape为 [batch_size,input_size] 的张量列表,列表中每一个元素都分别对应网络展开的时间步。
比如 28 x 28 的图像,将网络按 28 个时间步展开,则在每一个时间步中,可以输入一行 28(input_size) 个像素,经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。
具体实现
import tensorflow as tf
-
from tensorflow.contrib import rnn
-
from tensorflow.examples.tutorials.mnist import input_data
-
mnist=input_data.read_data_sets("/tmp/data/", one_hot=True)
-
time_steps=28
-
num_units=128
-
n_input=28
-
learning_rate=0.001
-
n_classes=10
-
batch_size=128
x = tf.placeholder ( "float", [None, time_steps, n_input] )
-
-
y = tf.placeholder ( "float", [None, n_classes] )
-
-
w = tf.Variable ( tf.random_normal ( [num_units, n_classes] ) )
-
b = tf.Variable ( tf.random_normal ( [n_classes] ) )
#将 shape 为 [batch_size, time_steps, n_input] 的输入转换成,长度为 time_steps 的 shape 为[batch_size, n_inputs] 的张量列表,再输入到 static_rnn。
-
input = tf.unstack ( x , time_steps, 1 )
#定义LSTM网络
-
lstm_layer = rnn.BasicLSTMCell ( num_units, forget_bias=1 )
-
outputs, _ = rnn.static_rnn ( lstm_layer, input, dtype="float32" )
-
prediction = tf.matmul ( outputs[-1], w) + b
loss = tf.reduce_mean ( tf.nn.softmax_cross_entropy_with_logits ( logits=prediction, labels=y ) )
-
op = tf.train.AdamOptimizer ( learning_rate=learning_rate ).minimize ( loss )
-
correct_prediction = tf.equal ( tf.argmax ( prediction, 1 ), tf.argmax ( y, 1 ) )
-
accuracy = tf.reduce_mean ( tf.cast ( correct_prediction, tf.float32 ) )
-
#running
-
init = tf.global_variables_initializer( )
-
withtf.Session() as sess:
-
sess.run ( init )
-
iter=1
-
while iter<800:
-
batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)
-
batch_x=batch_x.reshape((batch_size,time_steps,n_input))
-
sess.run(opt,feed_dict={x:batch_x,y:batch_y})
-
iter=iter+1
#calculating test accuracy
-
test_data = mnist.test.images[:128].reshape ( (-1, time_steps, n_input ) )
-
test_label = mnist.test.labels[:128]
-
print ( "Testing Accuracy:",sess.run ( accuracy,feed_dict={x:test_data,y:test_label} ) )
最终准确率为:
Testing Accuracy: 99.21%。