x 是 28*28 的手写数字图片，y 是该图片对应数字的标签（one-hot 编码）
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot encoded labels (TF 1.x tutorial dataset reader).
mnist=input_data.read_data_sets('D:/tensorflow/MNIST_data/mnist',one_hot=True)
# Hyperparameters for the LSTM classifier.
batch_size=100
n_batch=mnist.train.num_examples//batch_size  # number of batches per epoch
max_time=28  # each 28x28 image is fed as a 28-step sequence...
n_inputs=28  # ...of 28-pixel rows
lstm_size=100  # width of the LSTM hidden state
n_class=10  # digits 0-9
def Weights_variables():
    """Create the (lstm_size, n_class) output-projection weight variable.

    Initialised from a truncated normal with stddev 0.1.
    """
    return tf.Variable(tf.truncated_normal([lstm_size, n_class], stddev=0.1))
def biases_variables():
    """Create the per-class bias variable, initialised to a constant 0.1."""
    return tf.Variable(tf.constant(0.1, shape=[n_class]))
def RNN(X):
    """Run flattened MNIST images through a single-layer LSTM.

    Args:
        X: float32 tensor of shape (batch, 784) — flattened 28x28 images.

    Returns:
        The LSTM's final hidden state `h`, shape (batch, lstm_size).
    """
    # BUG FIX: the original wrapped only the `def` in `with tf.name_scope('RNN')`.
    # Name scopes apply when ops are *created* (i.e. when RNN() is called), so
    # that scope had no effect. Opening the scope inside the body actually
    # groups the RNN ops in TensorBoard.
    with tf.name_scope('RNN'):
        # (batch, 784) -> (batch, max_time=28 rows, n_inputs=28 pixels per row)
        inputs = tf.reshape(X, [-1, max_time, n_inputs])
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
        outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
        # final_state is an LSTMStateTuple (c, h); index 1 is the hidden state h.
        return final_state[1]
with tf.name_scope('input'):
    # x: flattened 28x28 images; y: one-hot digit labels.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
hiddlen = RNN(x)
with tf.name_scope('W'):
    w = Weights_variables()
with tf.name_scope('b'):
    b = biases_variables()
with tf.name_scope('prediction'):
    # Raw class scores; kept separate from the softmax so the loss can use them.
    logits = tf.matmul(hiddlen, w) + b
    prediction = tf.nn.softmax(logits)
with tf.name_scope('loss'):
    # BUG FIX: the original passed `prediction` (already softmaxed) as `logits`,
    # so softmax_cross_entropy_with_logits applied softmax a second time,
    # producing a wrong loss and weak gradients. The op expects raw logits.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    tf.summary.scalar('loss', loss)
with tf.name_scope('train_step'):
    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
with tf.name_scope('Accuracy'):
    # Fraction of samples where the arg-max class matches the label.
    prediction_value = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    Accuracy = tf.reduce_mean(tf.cast(prediction_value, dtype=tf.float32))
    tf.summary.scalar('Accuracy', Accuracy)
init = tf.global_variables_initializer()
merged = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(init)
    # Separate TensorBoard writers for train/test summaries.
    train_writer = tf.summary.FileWriter('logs/train', sess.graph)
    test_writer = tf.summary.FileWriter('logs/test', sess.graph)
    # Deliberately 250 batches short of a full epoch (quick demo run).
    for step in range(n_batch - 250):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: xs, y: ys})
        train_writer.add_summary(
            sess.run(merged, feed_dict={x: xs, y: ys}), step)
        xs, ys = mnist.test.next_batch(batch_size)
        test_writer.add_summary(
            sess.run(merged, feed_dict={x: xs, y: ys}), step)
其中在 RNN 的输出上又加了一层 softmax，用来预测 0~9 这 10 个数字各自的概率，如下：
可以看到准确率呈上升趋势。因为 demo 中训练次数较少，而且为了更快得到结果，这里还特意少跑了 250 个 batch，即 range(n_batch-250)，所以最终结果不是很好。