import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
def build_graph(num_hidden=128, learning_rate=0.01):
    """Build an LSTM classifier graph for 28x28 MNIST images.

    Each image row (28 pixels) is treated as one time step, so every
    image becomes a length-28 sequence fed through a single LSTM layer.

    Args:
        num_hidden: LSTM hidden-state size (default 128 preserves the
            original behavior).
        learning_rate: Adam learning rate (default 0.01 as before).

    Returns:
        Dict of graph handles: 'images' and 'labels' (placeholders),
        'train_op', 'global_step', 'accuracy'.
    """
    images = tf.placeholder(tf.float32, [None, 28, 28])
    labels = tf.placeholder(tf.float32, [None, 10])

    # Per-time-step input projection: 28 pixels -> num_hidden features.
    # NOTE(review): the original also created 'last' weight/bias variables
    # that were never used (the head below is tf.layers.dense); dropped.
    w_in = tf.Variable(tf.random_normal([28, num_hidden]))
    b_in = tf.Variable(tf.constant(0.1, shape=[num_hidden]))

    # Flatten (batch, 28, 28) -> (batch*28, 28) so one matmul projects
    # every row at once, then restore the time axis.
    x = tf.reshape(images, shape=[-1, 28])
    x_in = tf.matmul(x, w_in) + b_in
    x_in = tf.reshape(x_in, shape=[-1, 28, num_hidden])

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_hidden)
    # BUG FIX: the original built an initial state with a hard-coded batch
    # size of 128 (`zero_state(128, ...)`), so the graph only accepted
    # batches of exactly 128 examples. Supplying `dtype` and letting
    # dynamic_rnn create the zero state itself supports any batch size.
    outputs, _final_state = tf.nn.dynamic_rnn(
        lstm_cell, x_in, dtype=tf.float32, time_major=False)

    # Classify from the output of the last time step.
    logits = tf.layers.dense(outputs[:, -1, :], 10)

    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    # FIX: global_step was a float variable; an int64 step is conventional
    # and makes `step % 50` comparisons exact.
    global_step = tf.get_variable(
        "step", [], dtype=tf.int64,
        initializer=tf.zeros_initializer(), trainable=False)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    return {'images': images,
            'labels': labels,
            'train_op': train_op,
            'global_step': global_step,
            'accuracy': accuracy}
if __name__ == '__main__':
    batch_size = 128  # single definition instead of the magic 128 repeated below
    mnist = input_data.read_data_sets('/home/ly/MNIST_data', one_hot=True)
    with tf.Session() as sess:
        graph = build_graph()
        sess.run(tf.global_variables_initializer())
        # NOTE(review): this pipeline feeds data via feed_dict and defines no
        # input queues, so the coordinator/queue-runner scaffolding appears
        # vestigial; it is kept so shutdown behavior matches the original.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                # Reshape flat 784-pixel images into 28x28 row sequences.
                batch_x = batch_x.reshape([batch_size, 28, 28])
                feed_dict = {graph['images']: batch_x,
                             graph['labels']: batch_y}
                # BUG FIX: the original bound the fetched value to `accuracy`,
                # shadowing the graph-tensor name; renamed to `acc`.
                _, step, acc = sess.run(
                    [graph['train_op'], graph['global_step'], graph['accuracy']],
                    feed_dict=feed_dict)
                if step % 50 == 0:
                    print(acc)
                if step > 1200:
                    break
        except tf.errors.OutOfRangeError:
            print('==================Train Finished================')
        finally:
            # Always stop and join helper threads, even on early break/error.
            coord.request_stop()
            coord.join(threads)
# RNN-based recognition of the MNIST dataset.
# (Blog metadata: latest recommended article published 2022-01-15 14:49:40.)