TensorFlow学习笔记(四):手写数字识别之LSTM网络

长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。

在 TensorFlow 中,基础的 LSTM 单元格声明为:tf.contrib.rnn.BasicLSTMCell num_units ),其中 num_units 指一个 LSTM 单元格中的单元数,相当于前馈神经网络中的隐藏层神经元个数,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。

在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义为:tf.static_rnn cellinputs ),其中 inputs 接受 shape为 [batch_size,input_size] 的张量列表,列表中每一个元素都分别对应网络展开的时间步。

比如 28 x 28 的图像,将网络按 28 个时间步展开,则在每一个时间步中,可以输入一行 28(input_size) 个像素,经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。

具体实现

import tensorflow as tf

  1. from tensorflow.contrib import rnn

  2. from tensorflow.examples.tutorials.mnist  import input_data

  3. mnist=input_data.read_data_sets("/tmp/data/"one_hot=True)

  4. time_steps=28

  5. num_units=128

  6. n_input=28

  7. learning_rate=0.001

  8. n_classes=10

  9. batch_size=128

tf.placeholder "float", [Nonetime_stepsn_input] )

    1. tf.placeholder  "float",  [Nonen_classes] )

  1. tf.Variable tf.random_normal ( [num_units,  n_classes] ) )

  2. tf.Variable tf.random_normal ( [n_classes] ) )

#将 shape 为 [batch_size, time_steps, n_input] 的输入转换成,长度为 time_steps  的 shape 为[batch_size, n_inputs] 的张量列表,再输入到 static_rnn。

  1. input tf.unstack time_steps)

#定义LSTM网络

  1. lstm_layer rnn.BasicLSTMCell num_unitsforget_bias=)

  2. outputsrnn.static_rnn lstm_layerinputdtype="float32" )

  1. prediction tf.matmul outputs[-1], w) + b

loss tf.reduce_mean tf.nn.softmax_cross_entropy_with_logits logits=predictionlabels=) )

  1. op tf.train.AdamOptimizer learning_rate=learning_rate ).minimize loss )

  2. correct_prediction tf.equal tf.argmax prediction), tf.argmax y) )

  3. accuracy tf.reduce_mean tf.cast correct_predictiontf.float32 ) )

  1. #running

  2. init tf.global_variables_initializer( )

  3. withtf.Session() as sess:

  4.     sess.run init 

  5.     iter=1

  6.     while iter<800:

  7.         batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

  8.         batch_x=batch_x.reshape((batch_size,time_steps,n_input))

  9.         sess.run(opt,feed_dict={x:batch_x,y:batch_y})

  10.         iter=iter+1

        #calculating test accuracy

  1.         test_data mnist.test.images[:128].reshape ( (-1time_stepsn_input ) )

  2.         test_label mnist.test.labels[:128]

  3.         print "Testing Accuracy:",sess.run accuracy,feed_dict={x:test_data,y:test_label} ) )

最终准确率为:

Testing Accuracy: 99.21%。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值